diff --git a/.travis.yml b/.travis.yml index 1873832..7d468ba 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,7 @@ language: go go: - - 1.6 - - 1.7.x + - 1.13.x - master script: diff --git a/Gopkg.lock b/Gopkg.lock deleted file mode 100644 index 276890d..0000000 --- a/Gopkg.lock +++ /dev/null @@ -1,15 +0,0 @@ -# This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'. - - -[[projects]] - name = "github.com/DATA-DOG/go-sqlmock" - packages = ["."] - revision = "d76b18b42f285b792bf985118980ce9eacea9d10" - version = "v1.3.0" - -[solve-meta] - analyzer-name = "dep" - analyzer-version = 1 - inputs-digest = "988d3551024f0c7c7eb0c26b50b95e2f2d9ac8d7e71f35f5e3b226afc317987f" - solver-name = "gps-cdcl" - solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml deleted file mode 100644 index 8b3f987..0000000 --- a/Gopkg.toml +++ /dev/null @@ -1,26 +0,0 @@ - -# Gopkg.toml example -# -# Refer to https://github.com/golang/dep/blob/master/docs/Gopkg.toml.md -# for detailed Gopkg.toml documentation. -# -# required = ["github.com/user/thing/cmd/thing"] -# ignored = ["github.com/user/project/pkgX", "bitbucket.org/user/project/pkgA/pkgY"] -# -# [[constraint]] -# name = "github.com/user/project" -# version = "1.0.0" -# -# [[constraint]] -# name = "github.com/user/project2" -# branch = "dev" -# source = "github.com/myfork/project2" -# -# [[override]] -# name = "github.com/x/y" -# version = "2.4.0" - - -[[constraint]] - name = "github.com/DATA-DOG/go-sqlmock" - version = "~1.3.0" diff --git a/README.md b/README.md index 7e64de5..7984ee7 100644 --- a/README.md +++ b/README.md @@ -6,49 +6,49 @@ Create MYSQL dumps in Go without the `mysqldump` CLI as a dependancy. package main import ( - "database/sql" - "fmt" + "database/sql" + "fmt" - "github.com/JamesStewy/go-mysqldump" - _ "github.com/go-sql-driver/mysql" + "github.com/JamesStewy/go-mysqldump" + "github.com/go-sql-driver/mysql" ) func main() { - // Open connection to database - username := "your-user" - password := "your-pw" - hostname := "your-hostname" - port := "your-port" - dbname := "your-db" + // Open connection to database + config := mysql.NewConfig() + config.User = "your-user" + config.Passwd = "your-pw" + config.DBName = "your-db" + config.Net = "tcp" + config.Addr = "your-hostname:your-port" dumpDir := "dumps" // you should create this directory - dumpFilenameFormat := fmt.Sprintf("%s-20060102T150405", dbname) // accepts time layout string and add .sql at the end of file - - db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%s)/%s", username, password, hostname, port, dbname)) - if err != nil { - fmt.Println("Error opening database: ", err) - return - } - - // Register database with mysqldump - dumper, err := mysqldump.Register(db, dumpDir, dumpFilenameFormat) - if err != nil { - fmt.Println("Error registering databse:", err) - return - } - - // Dump database to file - resultFilename, err := dumper.Dump() - if err != nil { - fmt.Println("Error dumping:", err) - return - } - fmt.Printf("File is saved to %s", resultFilename) - - // Close dumper and connected database - dumper.Close() + dumpFilenameFormat := fmt.Sprintf("%s-20060102T150405", config.DBName) // accepts time layout string and add .sql at the end of file + + db, err := sql.Open("mysql", config.FormatDSN()) + if err != nil { + fmt.Println("Error opening database: ", err) + return + } + + // Register database with mysqldump + dumper, err := mysqldump.Register(db, dumpDir, dumpFilenameFormat) + if err != nil { + fmt.Println("Error registering databse:", err) + return + } + + // Dump database to file + err := dumper.Dump() + if err != nil { + fmt.Println("Error dumping:", err) + return + } + fmt.Printf("File is saved to %s", dumpFilenameFormat) + + // Close dumper, connected database and file stream. + dumper.Close() } - ``` [![GoDoc](https://godoc.org/github.com/JamesStewy/go-mysqldump?status.svg)](https://godoc.org/github.com/JamesStewy/go-mysqldump) diff --git a/doc.go b/doc.go index 0851a97..187caa1 100644 --- a/doc.go +++ b/doc.go @@ -3,44 +3,52 @@ Create MYSQL dumps in Go without the 'mysqldump' CLI as a dependancy. Example -This example uses the mymysql driver (example 7 https://github.com/ziutek/mymysql) to connect to a mysql instance. +This example uses the mysql driver (https://github.com/go-sql-driver/mysql) to connect to a mysql instance. package main import ( - "database/sql" - "fmt" - "github.com/JamesStewy/go-mysqldump" - "github.com/ziutek/mymysql/godrv" - "time" + "database/sql" + "fmt" + + "github.com/JamesStewy/go-mysqldump" + "github.com/go-sql-driver/mysql" ) func main() { - // Register the mymysql driver - godrv.Register("SET NAMES utf8") - // Open connection to database - db, err := sql.Open("mymysql", "tcp:host:port*database/user/password") - if err != nil { - fmt.Println("Error opening databse:", err) + config := mysql.NewConfig() + config.User = "your-user" + config.Passwd = "your-pw" + config.DBName = "your-db" + config.Net = "tcp" + config.Addr = "your-hostname:your-port" + + dumpDir := "dumps" // you should create this directory + dumpFilenameFormat := fmt.Sprintf("%s-20060102T150405", dbname) // accepts time layout string and add .sql at the end of file + + db, err := sql.Open("mysql", config.FormatDNS()) + if err != nil { + fmt.Println("Error opening database: ", err) return } // Register database with mysqldump - dumper, err := mysqldump.Register(db, "dumps", time.ANSIC) + dumper, err := mysqldump.Register(db, dumpDir, dumpFilenameFormat) if err != nil { - fmt.Println("Error registering databse:", err) - return + fmt.Println("Error registering databse:", err) + return } // Dump database to file - err = dumper.Dump() + resultFilename, err := dumper.Dump() if err != nil { - fmt.Println("Error dumping:", err) - return + fmt.Println("Error dumping:", err) + return } + fmt.Printf("File is saved to %s", resultFilename) - // Close dumper and connected database + // Close dumper, connected database and file stream. dumper.Close() } */ diff --git a/dump.go b/dump.go index 4b6b841..b3f5f4c 100644 --- a/dump.go +++ b/dump.go @@ -1,31 +1,64 @@ package mysqldump import ( + "bytes" + "context" "database/sql" "errors" - "os" - "path" - "strings" + "fmt" + "io" + "reflect" "text/template" "time" ) +/* +Data struct to configure dump behavior + + Out: Stream to wite to + Connection: Database connection to dump + IgnoreTables: Mark sensitive tables to ignore + MaxAllowedPacket: Sets the largest packet size to use in backups + LockTables: Lock all tables for the duration of the dump +*/ +type Data struct { + Out io.Writer + Connection *sql.DB + IgnoreTables []string + MaxAllowedPacket int + LockTables bool + + tx *sql.Tx + headerTmpl *template.Template + tableTmpl *template.Template + footerTmpl *template.Template + err error +} + type table struct { - Name string - SQL string - Values string + Name string + Err error + + data *Data + rows *sql.Rows + values []interface{} } -type dump struct { +type metaData struct { DumpVersion string ServerVersion string - Tables []*table CompleteTime string } -const version = "0.2.2" +const ( + // Version of this plugin for easy reference + Version = "0.5.1" -const tmpl = `-- Go SQL Dump {{ .DumpVersion }} + defaultMaxAllowedPacket = 4194304 +) + +// takes a *metaData +const headerTmpl = `-- Go SQL Dump {{ .DumpVersion }} -- -- ------------------------------------------------------ -- Server version {{ .ServerVersion }} @@ -33,207 +66,375 @@ const tmpl = `-- Go SQL Dump {{ .DumpVersion }} /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; -/*!40101 SET NAMES utf8 */; + SET NAMES utf8mb4 ; /*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; /*!40103 SET TIME_ZONE='+00:00' */; /*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; +` +// takes a *metaData +const footerTmpl = `/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; -{{range .Tables}} +/*!40101 SET SQL_MODE=@OLD_SQL_MODE */; +/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; +/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; +/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; +/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; +/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; +/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; + +-- Dump completed on {{ .CompleteTime }} +` + +// Takes a *table +const tableTmpl = ` -- --- Table structure for table {{ .Name }} +-- Table structure for table {{ .NameEsc }} -- -DROP TABLE IF EXISTS {{ .Name }}; +DROP TABLE IF EXISTS {{ .NameEsc }}; /*!40101 SET @saved_cs_client = @@character_set_client */; -/*!40101 SET character_set_client = utf8 */; -{{ .SQL }}; + SET character_set_client = utf8mb4 ; +{{ .CreateSQL }}; /*!40101 SET character_set_client = @saved_cs_client */; + -- --- Dumping data for table {{ .Name }} +-- Dumping data for table {{ .NameEsc }} -- -LOCK TABLES {{ .Name }} WRITE; -/*!40000 ALTER TABLE {{ .Name }} DISABLE KEYS */; -{{ if .Values }} -INSERT INTO {{ .Name }} VALUES {{ .Values }}; -{{ end }} -/*!40000 ALTER TABLE {{ .Name }} ENABLE KEYS */; +LOCK TABLES {{ .NameEsc }} WRITE; +/*!40000 ALTER TABLE {{ .NameEsc }} DISABLE KEYS */; +{{ range $value := .Stream }} +{{- $value }} +{{ end -}} +/*!40000 ALTER TABLE {{ .NameEsc }} ENABLE KEYS */; UNLOCK TABLES; -{{ end }} --- Dump completed on {{ .CompleteTime }} ` -// Creates a MYSQL Dump based on the options supplied through the dumper. -func (d *Dumper) Dump() (string, error) { - name := time.Now().Format(d.format) - p := path.Join(d.dir, name+".sql") +const nullType = "NULL" - // Check dump directory - if e, _ := exists(p); e { - return p, errors.New("Dump '" + name + "' already exists.") +// Dump data using struct +func (data *Data) Dump() error { + meta := metaData{ + DumpVersion: Version, } - // Create .sql file - f, err := os.Create(p) + if data.MaxAllowedPacket == 0 { + data.MaxAllowedPacket = defaultMaxAllowedPacket + } - if err != nil { - return p, err + if err := data.getTemplates(); err != nil { + return err } - defer f.Close() + // Start the read only transaction and defer the rollback until the end + // This way the database will have the exact state it did at the begining of + // the backup and nothing can be accidentally committed + if err := data.begin(); err != nil { + return err + } + defer data.rollback() - data := dump{ - DumpVersion: version, - Tables: make([]*table, 0), + if err := meta.updateServerVersion(data); err != nil { + return err } - // Get server version - if data.ServerVersion, err = getServerVersion(d.db); err != nil { - return p, err + if err := data.headerTmpl.Execute(data.Out, meta); err != nil { + return err } - // Get tables - tables, err := getTables(d.db) + tables, err := data.getTables() if err != nil { - return p, err + return err + } + + // Lock all tables before dumping if present + if data.LockTables && len(tables) > 0 { + var b bytes.Buffer + b.WriteString("LOCK TABLES ") + for index, name := range tables { + if index != 0 { + b.WriteString(",") + } + b.WriteString("`" + name + "` READ /*!32311 LOCAL */") + } + + if _, err := data.Connection.Exec(b.String()); err != nil { + return err + } + + defer data.Connection.Exec("UNLOCK TABLES") } - // Get sql for each table for _, name := range tables { - if t, err := createTable(d.db, name); err == nil { - data.Tables = append(data.Tables, t) - } else { - return p, err + if err := data.dumpTable(name); err != nil { + return err } } + if data.err != nil { + return data.err + } - // Set complete time - data.CompleteTime = time.Now().String() + meta.CompleteTime = time.Now().String() + return data.footerTmpl.Execute(data.Out, meta) +} + +// MARK: - Private methods - // Write dump to file - t, err := template.New("mysqldump").Parse(tmpl) +// begin starts a read only transaction that will be whatever the database was +// when it was called +func (data *Data) begin() (err error) { + data.tx, err = data.Connection.BeginTx(context.Background(), &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + return +} + +// rollback cancels the transaction +func (data *Data) rollback() error { + return data.tx.Rollback() +} + +// MARK: writter methods + +func (data *Data) dumpTable(name string) error { + if data.err != nil { + return data.err + } + table := data.createTable(name) + return data.writeTable(table) +} + +func (data *Data) writeTable(table *table) error { + if err := data.tableTmpl.Execute(data.Out, table); err != nil { + return err + } + return table.Err +} + +// MARK: get methods + +// getTemplates initilaizes the templates on data from the constants in this file +func (data *Data) getTemplates() (err error) { + data.headerTmpl, err = template.New("mysqldumpHeader").Parse(headerTmpl) if err != nil { - return p, err + return } - if err = t.Execute(f, data); err != nil { - return p, err + + data.tableTmpl, err = template.New("mysqldumpTable").Parse(tableTmpl) + if err != nil { + return } - return p, nil + data.footerTmpl, err = template.New("mysqldumpTable").Parse(footerTmpl) + if err != nil { + return + } + return } -func getTables(db *sql.DB) ([]string, error) { +func (data *Data) getTables() ([]string, error) { tables := make([]string, 0) - // Get table list - rows, err := db.Query("SHOW TABLES") + rows, err := data.tx.Query("SHOW TABLES") if err != nil { return tables, err } defer rows.Close() - // Read result for rows.Next() { var table sql.NullString if err := rows.Scan(&table); err != nil { return tables, err } - tables = append(tables, table.String) + if table.Valid && !data.isIgnoredTable(table.String) { + tables = append(tables, table.String) + } } return tables, rows.Err() } -func getServerVersion(db *sql.DB) (string, error) { - var server_version sql.NullString - if err := db.QueryRow("SELECT version()").Scan(&server_version); err != nil { - return "", err +func (data *Data) isIgnoredTable(name string) bool { + for _, item := range data.IgnoreTables { + if item == name { + return true + } } - return server_version.String, nil + return false } -func createTable(db *sql.DB, name string) (*table, error) { - var err error - t := &table{Name: name} +func (meta *metaData) updateServerVersion(data *Data) (err error) { + var serverVersion sql.NullString + err = data.tx.QueryRow("SELECT version()").Scan(&serverVersion) + meta.ServerVersion = serverVersion.String + return +} - if t.SQL, err = createTableSQL(db, name); err != nil { - return nil, err - } +// MARK: create methods - if t.Values, err = createTableValues(db, name); err != nil { - return nil, err +func (data *Data) createTable(name string) *table { + return &table{ + Name: name, + data: data, } - - return t, nil } -func createTableSQL(db *sql.DB, name string) (string, error) { - // Get table creation SQL - var table_return sql.NullString - var table_sql sql.NullString - err := db.QueryRow("SHOW CREATE TABLE "+name).Scan(&table_return, &table_sql) +func (table *table) NameEsc() string { + return "`" + table.Name + "`" +} - if err != nil { +func (table *table) CreateSQL() (string, error) { + var tableReturn, tableSQL sql.NullString + if err := table.data.tx.QueryRow("SHOW CREATE TABLE "+table.NameEsc()).Scan(&tableReturn, &tableSQL); err != nil { return "", err } - if table_return.String != name { + + if tableReturn.String != table.Name { return "", errors.New("Returned table is not the same as requested table") } - return table_sql.String, nil + return tableSQL.String, nil } -func createTableValues(db *sql.DB, name string) (string, error) { - // Get Data - rows, err := db.Query("SELECT * FROM " + name) +func (table *table) Init() (err error) { + if len(table.values) != 0 { + return errors.New("can't init twice") + } + + table.rows, err = table.data.tx.Query("SELECT * FROM " + table.NameEsc()) if err != nil { - return "", err + return err } - defer rows.Close() - // Get columns - columns, err := rows.Columns() + columns, err := table.rows.Columns() if err != nil { - return "", err + return err } if len(columns) == 0 { - return "", errors.New("No columns in table " + name + ".") + return errors.New("No columns in table " + table.Name + ".") } - // Read data - data_text := make([]string, 0) - for rows.Next() { - // Init temp data storage - - //ptrs := make([]interface{}, len(columns)) - //var ptrs []interface {} = make([]*sql.NullString, len(columns)) + tt, err := table.rows.ColumnTypes() + if err != nil { + return err + } - data := make([]*sql.NullString, len(columns)) - ptrs := make([]interface{}, len(columns)) - for i, _ := range data { - ptrs[i] = &data[i] + var t reflect.Type + table.values = make([]interface{}, len(tt)) + for i, tp := range tt { + st := tp.ScanType() + if tp.DatabaseTypeName() == "BLOB" { + t = reflect.TypeOf(sql.RawBytes{}) + } else if st != nil && (st.Kind() == reflect.Int || + st.Kind() == reflect.Int8 || + st.Kind() == reflect.Int16 || + st.Kind() == reflect.Int32 || + st.Kind() == reflect.Int64) { + t = reflect.TypeOf(sql.NullInt64{}) + } else { + t = reflect.TypeOf(sql.NullString{}) } + table.values[i] = reflect.New(t).Interface() + } + return nil +} - // Read data - if err := rows.Scan(ptrs...); err != nil { - return "", err +func (table *table) Next() bool { + if table.rows == nil { + if err := table.Init(); err != nil { + table.Err = err + return false } + } + // Fallthrough + if table.rows.Next() { + if err := table.rows.Scan(table.values...); err != nil { + table.Err = err + return false + } else if err := table.rows.Err(); err != nil { + table.Err = err + return false + } + } else { + table.rows.Close() + table.rows = nil + return false + } + return true +} + +func (table *table) RowValues() string { + return table.RowBuffer().String() +} - dataStrings := make([]string, len(columns)) +func (table *table) RowBuffer() *bytes.Buffer { + var b bytes.Buffer + b.WriteString("(") - for key, value := range data { - if value != nil && value.Valid { - dataStrings[key] = "'" + value.String + "'" + for key, value := range table.values { + if key != 0 { + b.WriteString(",") + } + switch s := value.(type) { + case nil: + b.WriteString(nullType) + case *sql.NullString: + if s.Valid { + fmt.Fprintf(&b, "'%s'", sanitize(s.String)) + } else { + b.WriteString(nullType) + } + case *sql.NullInt64: + if s.Valid { + fmt.Fprintf(&b, "%d", s.Int64) + } else { + b.WriteString(nullType) + } + case *sql.RawBytes: + if len(*s) == 0 { + b.WriteString(nullType) } else { - dataStrings[key] = "null" + fmt.Fprintf(&b, "_binary '%s'", sanitize(string(*s))) } + default: + fmt.Fprintf(&b, "'%s'", value) } - - data_text = append(data_text, "("+strings.Join(dataStrings, ",")+")") } + b.WriteString(")") + + return &b +} + +func (table *table) Stream() <-chan string { + valueOut := make(chan string, 1) + go func() { + defer close(valueOut) + var insert bytes.Buffer + + for table.Next() { + b := table.RowBuffer() + // Truncate our insert if it won't fit + if insert.Len() != 0 && insert.Len()+b.Len() > table.data.MaxAllowedPacket-1 { + insert.WriteString(";") + valueOut <- insert.String() + insert.Reset() + } - return strings.Join(data_text, ","), rows.Err() + if insert.Len() == 0 { + fmt.Fprintf(&insert, "INSERT INTO %s VALUES ", table.NameEsc()) + } else { + insert.WriteString(",") + } + b.WriteTo(&insert) + } + if insert.Len() != 0 { + insert.WriteString(";") + valueOut <- insert.String() + } + }() + return valueOut } diff --git a/dump_test.go b/dump_test.go index 24dd7e0..ec0ca12 100644 --- a/dump_test.go +++ b/dump_test.go @@ -1,22 +1,35 @@ package mysqldump import ( - "io/ioutil" - "os" + "bytes" + "database/sql" "reflect" "strings" "testing" sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" ) -func TestGetTablesOk(t *testing.T) { - db, mock, err := sqlmock.New() +func getMockData() (data *Data, mock sqlmock.Sqlmock, err error) { + var db *sql.DB + db, mock, err = sqlmock.New() if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + return } + mock.ExpectBegin() - defer db.Close() + data = &Data{ + Connection: db, + } + err = data.begin() + return +} + +func TestGetTablesOk(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). AddRow("Test_Table_1"). @@ -24,30 +37,41 @@ func TestGetTablesOk(t *testing.T) { mock.ExpectQuery("^SHOW TABLES$").WillReturnRows(rows) - result, err := getTables(db) - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + result, err := data.getTables() + assert.NoError(t, err) // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - expectedResult := []string{"Test_Table_1", "Test_Table_2"} + assert.EqualValues(t, []string{"Test_Table_1", "Test_Table_2"}, result) +} - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", result, expectedResult) - } +func TestIgnoreTablesOk(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() + + rows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). + AddRow("Test_Table_1"). + AddRow("Test_Table_2") + + mock.ExpectQuery("^SHOW TABLES$").WillReturnRows(rows) + + data.IgnoreTables = []string{"Test_Table_1"} + + result, err := data.getTables() + assert.NoError(t, err) + + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") + + assert.EqualValues(t, []string{"Test_Table_2"}, result) } func TestGetTablesNil(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). AddRow("Test_Table_1"). @@ -56,76 +80,52 @@ func TestGetTablesNil(t *testing.T) { mock.ExpectQuery("^SHOW TABLES$").WillReturnRows(rows) - result, err := getTables(db) - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + result, err := data.getTables() + assert.NoError(t, err) // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - expectedResult := []string{"Test_Table_1", "", "Test_Table_3"} - - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", expectedResult, result) - } + assert.EqualValues(t, []string{"Test_Table_1", "Test_Table_3"}, result) } func TestGetServerVersionOk(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"Version()"}). AddRow("test_version") mock.ExpectQuery("^SELECT version()").WillReturnRows(rows) - result, err := getServerVersion(db) - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + meta := metaData{} - // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, meta.updateServerVersion(data), "error was not expected while updating stats") - expectedResult := "test_version" + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", expectedResult, result) - } + assert.Equal(t, "test_version", meta.ServerVersion) } func TestCreateTableSQLOk(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`s` char(60) DEFAULT NULL, PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") - mock.ExpectQuery("^SHOW CREATE TABLE Test_Table$").WillReturnRows(rows) + mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(rows) - result, err := createTableSQL(db, "Test_Table") + table := data.createTable("Test_Table") - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + result, err := table.CreateSQL() + assert.NoError(t, err) // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") expectedResult := "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`s` char(60) DEFAULT NULL, PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1" @@ -134,76 +134,108 @@ func TestCreateTableSQLOk(t *testing.T) { } } -func TestCreateTableValuesOk(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() +func TestCreateTableRowValues(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"id", "email", "name"}). AddRow(1, "test@test.de", "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2") - mock.ExpectQuery("^SELECT (.+) FROM test$").WillReturnRows(rows) + mock.ExpectQuery("^SELECT (.+) FROM `test`$").WillReturnRows(rows) - result, err := createTableValues(db, "test") - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + table := data.createTable("test") + + assert.True(t, table.Next()) + + result := table.RowValues() + assert.NoError(t, table.Err) // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - expectedResult := "('1','test@test.de','Test Name 1'),('2','test2@test.de','Test Name 2')" + assert.EqualValues(t, "('1','test@test.de','Test Name 1')", result) +} - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", expectedResult, result) - } +func TestCreateTableValuesSteam(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() + + rows := sqlmock.NewRows([]string{"id", "email", "name"}). + AddRow(1, "test@test.de", "Test Name 1"). + AddRow(2, "test2@test.de", "Test Name 2") + + mock.ExpectQuery("^SELECT (.+) FROM `test`$").WillReturnRows(rows) + + data.MaxAllowedPacket = 4096 + + table := data.createTable("test") + + s := table.Stream() + assert.EqualValues(t, "INSERT INTO `test` VALUES ('1','test@test.de','Test Name 1'),('2','test2@test.de','Test Name 2');", <-s) + + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") } -func TestCreateTableValuesNil(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } +func TestCreateTableValuesSteamSmallPackets(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() + + rows := sqlmock.NewRows([]string{"id", "email", "name"}). + AddRow(1, "test@test.de", "Test Name 1"). + AddRow(2, "test2@test.de", "Test Name 2") - defer db.Close() + mock.ExpectQuery("^SELECT (.+) FROM `test`$").WillReturnRows(rows) + + data.MaxAllowedPacket = 64 + + table := data.createTable("test") + + s := table.Stream() + assert.EqualValues(t, "INSERT INTO `test` VALUES ('1','test@test.de','Test Name 1');", <-s) + assert.EqualValues(t, "INSERT INTO `test` VALUES ('2','test2@test.de','Test Name 2');", <-s) + + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") +} + +func TestCreateTableAllValuesWithNil(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() rows := sqlmock.NewRows([]string{"id", "email", "name"}). AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2"). AddRow(3, "", "Test Name 3") - mock.ExpectQuery("^SELECT (.+) FROM test$").WillReturnRows(rows) + mock.ExpectQuery("^SELECT (.+) FROM `test`$").WillReturnRows(rows) - result, err := createTableValues(db, "test") - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) + table := data.createTable("test") + + results := make([]string, 0) + for table.Next() { + row := table.RowValues() + assert.NoError(t, table.Err) + results = append(results, row) } // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - expectedResult := "('1',null,'Test Name 1'),('2','test2@test.de','Test Name 2'),('3','','Test Name 3')" + expectedResults := []string{"('1',NULL,'Test Name 1')", "('2','test2@test.de','Test Name 2')", "('3','','Test Name 3')"} - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", expectedResult, result) - } + assert.EqualValues(t, expectedResults, results) } func TestCreateTableOk(t *testing.T) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } - - defer db.Close() + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() createTableRows := sqlmock.NewRows([]string{"Table", "Create Table"}). AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`s` char(60) DEFAULT NULL, PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") @@ -212,126 +244,97 @@ func TestCreateTableOk(t *testing.T) { AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2") - mock.ExpectQuery("^SHOW CREATE TABLE Test_Table$").WillReturnRows(createTableRows) - mock.ExpectQuery("^SELECT (.+) FROM Test_Table$").WillReturnRows(createTableValueRows) + mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(createTableRows) + mock.ExpectQuery("^SELECT (.+) FROM `Test_Table`$").WillReturnRows(createTableValueRows) - result, err := createTable(db, "Test_Table") - if err != nil { - t.Errorf("error was not expected while updating stats: %s", err) - } + var buf bytes.Buffer + data.Out = &buf + data.MaxAllowedPacket = 4096 - // we make sure that all expectations were met - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expections: %s", err) - } + assert.NoError(t, data.getTemplates()) - expectedResult := &table{ - Name: "Test_Table", - SQL: "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`s` char(60) DEFAULT NULL, PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1", - Values: "('1',null,'Test Name 1'),('2','test2@test.de','Test Name 2')", - } + table := data.createTable("Test_Table") - if !reflect.DeepEqual(result, expectedResult) { - t.Fatalf("expected %#v, got %#v", expectedResult, result) - } -} + data.writeTable(table) -func TestDumpOk(t *testing.T) { + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") - tmpFile := "/tmp/test_format.sql" - os.Remove(tmpFile) + expectedResult := ` +-- +-- Table structure for table ~Test_Table~ +-- - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } +DROP TABLE IF EXISTS ~Test_Table~; +/*!40101 SET @saved_cs_client = @@character_set_client */; + SET character_set_client = utf8mb4 ; +CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; +/*!40101 SET character_set_client = @saved_cs_client */; - defer db.Close() +-- +-- Dumping data for table ~Test_Table~ +-- - showTablesRows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). - AddRow("Test_Table") +LOCK TABLES ~Test_Table~ WRITE; +/*!40000 ALTER TABLE ~Test_Table~ DISABLE KEYS */; +INSERT INTO ~Test_Table~ VALUES ('1',NULL,'Test Name 1'),('2','test2@test.de','Test Name 2'); +/*!40000 ALTER TABLE ~Test_Table~ ENABLE KEYS */; +UNLOCK TABLES; +` + result := strings.Replace(buf.String(), "`", "~", -1) + assert.Equal(t, expectedResult, result) +} - serverVersionRows := sqlmock.NewRows([]string{"Version()"}). - AddRow("test_version") +func TestCreateTableOkSmallPackets(t *testing.T) { + data, mock, err := getMockData() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer data.Close() createTableRows := sqlmock.NewRows([]string{"Table", "Create Table"}). - AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`email` char(60) DEFAULT NULL, `name` char(60), PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") + AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`s` char(60) DEFAULT NULL, PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") createTableValueRows := sqlmock.NewRows([]string{"id", "email", "name"}). AddRow(1, nil, "Test Name 1"). AddRow(2, "test2@test.de", "Test Name 2") - mock.ExpectQuery("^SELECT version()").WillReturnRows(serverVersionRows) - mock.ExpectQuery("^SHOW TABLES$").WillReturnRows(showTablesRows) - mock.ExpectQuery("^SHOW CREATE TABLE Test_Table$").WillReturnRows(createTableRows) - mock.ExpectQuery("^SELECT (.+) FROM Test_Table$").WillReturnRows(createTableValueRows) - - dumper := &Dumper{ - db: db, - format: "test_format", - dir: "/tmp/", - } - - path, err := dumper.Dump() + mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(createTableRows) + mock.ExpectQuery("^SELECT (.+) FROM `Test_Table`$").WillReturnRows(createTableValueRows) - if path == "" { - t.Errorf("No empty path was expected while dumping the database") - } + var buf bytes.Buffer + data.Out = &buf + data.MaxAllowedPacket = 64 - if err != nil { - t.Errorf("error was not expected while dumping the database: %s", err) - } + assert.NoError(t, data.getTemplates()) - f, err := ioutil.ReadFile("/tmp/test_format.sql") - - if err != nil { - t.Errorf("error was not expected while reading the file: %s", err) - } - - result := strings.Replace(strings.Split(string(f), "-- Dump completed")[0], "`", "\\", -1) - - expected := `-- Go SQL Dump ` + version + ` --- --- ------------------------------------------------------ --- Server version test_version - -/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; -/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; -/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; -/*!40101 SET NAMES utf8 */; -/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; -/*!40103 SET TIME_ZONE='+00:00' */; -/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; -/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; -/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; -/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; + table := data.createTable("Test_Table") + data.writeTable(table) + // we make sure that all expectations were met + assert.NoError(t, mock.ExpectationsWereMet(), "there were unfulfilled expections") + expectedResult := ` -- --- Table structure for table Test_Table +-- Table structure for table ~Test_Table~ -- -DROP TABLE IF EXISTS Test_Table; +DROP TABLE IF EXISTS ~Test_Table~; /*!40101 SET @saved_cs_client = @@character_set_client */; -/*!40101 SET character_set_client = utf8 */; -CREATE TABLE 'Test_Table' (\id\ int(11) NOT NULL AUTO_INCREMENT,\email\ char(60) DEFAULT NULL, \name\ char(60), PRIMARY KEY (\id\))ENGINE=InnoDB DEFAULT CHARSET=latin1; + SET character_set_client = utf8mb4 ; +CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~s~ char(60) DEFAULT NULL, PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; /*!40101 SET character_set_client = @saved_cs_client */; + -- --- Dumping data for table Test_Table +-- Dumping data for table ~Test_Table~ -- -LOCK TABLES Test_Table WRITE; -/*!40000 ALTER TABLE Test_Table DISABLE KEYS */; - -INSERT INTO Test_Table VALUES ('1',null,'Test Name 1'),('2','test2@test.de','Test Name 2'); - -/*!40000 ALTER TABLE Test_Table ENABLE KEYS */; +LOCK TABLES ~Test_Table~ WRITE; +/*!40000 ALTER TABLE ~Test_Table~ DISABLE KEYS */; +INSERT INTO ~Test_Table~ VALUES ('1',NULL,'Test Name 1'); +INSERT INTO ~Test_Table~ VALUES ('2','test2@test.de','Test Name 2'); +/*!40000 ALTER TABLE ~Test_Table~ ENABLE KEYS */; UNLOCK TABLES; - ` - - if !reflect.DeepEqual(result, expected) { - t.Fatalf("expected %#v, got %#v", expected, result) - } + result := strings.Replace(buf.String(), "`", "~", -1) + assert.Equal(t, expectedResult, result) } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..13a96e9 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/jamf/go-mysqldump + +require ( + github.com/DATA-DOG/go-sqlmock v1.3.0 + github.com/stretchr/testify v1.4.0 +) + +go 1.13 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..9165637 --- /dev/null +++ b/go.sum @@ -0,0 +1,13 @@ +github.com/DATA-DOG/go-sqlmock v1.3.0 h1:ljjRxlddjfChBJdFKJs5LuCwCWPLaC1UZLwAo3PBBMk= +github.com/DATA-DOG/go-sqlmock v1.3.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/mysqldump.go b/mysqldump.go index aebe950..258da8c 100644 --- a/mysqldump.go +++ b/mysqldump.go @@ -3,44 +3,66 @@ package mysqldump import ( "database/sql" "errors" + "io" "os" + "path" + "time" ) -// Dumper represents a database. -type Dumper struct { - db *sql.DB - format string - dir string -} - /* -Creates a new dumper. +Register a new dumper. db: Database that will be dumped (https://golang.org/pkg/database/sql/#DB). dir: Path to the directory where the dumps will be stored. format: Format to be used to name each dump file. Uses time.Time.Format (https://golang.org/pkg/time/#Time.Format). format appended with '.sql'. */ -func Register(db *sql.DB, dir, format string) (*Dumper, error) { +func Register(db *sql.DB, dir, format string) (*Data, error) { if !isDir(dir) { return nil, errors.New("Invalid directory") } - return &Dumper{ - db: db, - format: format, - dir: dir, + name := time.Now().Format(format) + p := path.Join(dir, name+".sql") + + // Check dump directory + if e, _ := exists(p); e { + return nil, errors.New("Dump '" + name + "' already exists.") + } + + // Create .sql file + f, err := os.Create(p) + + if err != nil { + return nil, err + } + + return &Data{ + Out: f, + Connection: db, }, nil } -// Closes the dumper. -// Will also close the database the dumper is connected to. +// Dump Creates a MYSQL dump from the connection to the stream. +func Dump(db *sql.DB, out io.Writer) error { + return (&Data{ + Connection: db, + Out: out, + }).Dump() +} + +// Close the dumper. +// Will also close the database the dumper is connected to as well as the out stream if it has a Close method. // // Not required. -func (d *Dumper) Close() error { +func (d *Data) Close() error { defer func() { - d.db = nil + d.Connection = nil + d.Out = nil }() - return d.db.Close() + if out, ok := d.Out.(io.Closer); ok { + out.Close() + } + return d.Connection.Close() } func exists(p string) (bool, os.FileInfo) { @@ -56,13 +78,6 @@ func exists(p string) (bool, os.FileInfo) { return true, fi } -func isFile(p string) bool { - if e, fi := exists(p); e { - return fi.Mode().IsRegular() - } - return false -} - func isDir(p string) bool { if e, fi := exists(p); e { return fi.Mode().IsDir() diff --git a/mysqldump_test.go b/mysqldump_test.go new file mode 100644 index 0000000..26f5114 --- /dev/null +++ b/mysqldump_test.go @@ -0,0 +1,151 @@ +package mysqldump + +import ( + "bytes" + "io/ioutil" + "strings" + "testing" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +const expected = `-- Go SQL Dump ` + Version + ` +-- +-- ------------------------------------------------------ +-- Server version test_version + +/*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */; +/*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */; +/*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */; + SET NAMES utf8mb4 ; +/*!40103 SET @OLD_TIME_ZONE=@@TIME_ZONE */; +/*!40103 SET TIME_ZONE='+00:00' */; +/*!40014 SET @OLD_UNIQUE_CHECKS=@@UNIQUE_CHECKS, UNIQUE_CHECKS=0 */; +/*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */; +/*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */; +/*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */; + +-- +-- Table structure for table ~Test_Table~ +-- + +DROP TABLE IF EXISTS ~Test_Table~; +/*!40101 SET @saved_cs_client = @@character_set_client */; + SET character_set_client = utf8mb4 ; +CREATE TABLE 'Test_Table' (~id~ int(11) NOT NULL AUTO_INCREMENT,~email~ char(60) DEFAULT NULL, ~name~ char(60), PRIMARY KEY (~id~))ENGINE=InnoDB DEFAULT CHARSET=latin1; +/*!40101 SET character_set_client = @saved_cs_client */; + +-- +-- Dumping data for table ~Test_Table~ +-- + +LOCK TABLES ~Test_Table~ WRITE; +/*!40000 ALTER TABLE ~Test_Table~ DISABLE KEYS */; +INSERT INTO ~Test_Table~ VALUES ('1',NULL,'Test Name 1'),('2','test2@test.de','Test Name 2'); +/*!40000 ALTER TABLE ~Test_Table~ ENABLE KEYS */; +UNLOCK TABLES; +/*!40103 SET TIME_ZONE=@OLD_TIME_ZONE */; + +/*!40101 SET SQL_MODE=@OLD_SQL_MODE */; +/*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */; +/*!40014 SET UNIQUE_CHECKS=@OLD_UNIQUE_CHECKS */; +/*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */; +/*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */; +/*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */; +/*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */; + +` + +func RunDump(t testing.TB, data *Data) { + db, mock, err := sqlmock.New() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer db.Close() + + data.Connection = db + showTablesRows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). + AddRow("Test_Table") + + serverVersionRows := sqlmock.NewRows([]string{"Version()"}). + AddRow("test_version") + + createTableRows := sqlmock.NewRows([]string{"Table", "Create Table"}). + AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`email` char(60) DEFAULT NULL, `name` char(60), PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") + + createTableValueRows := sqlmock.NewRows([]string{"id", "email", "name"}). + AddRow(1, nil, "Test Name 1"). + AddRow(2, "test2@test.de", "Test Name 2") + + mock.ExpectBegin() + mock.ExpectQuery(`^SELECT version\(\)$`).WillReturnRows(serverVersionRows) + mock.ExpectQuery(`^SHOW TABLES$`).WillReturnRows(showTablesRows) + mock.ExpectExec("^LOCK TABLES `Test_Table` READ /\\*!32311 LOCAL \\*/$").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(createTableRows) + mock.ExpectQuery("^SELECT (.+) FROM `Test_Table`$").WillReturnRows(createTableValueRows) + mock.ExpectRollback() + + assert.NoError(t, data.Dump(), "an error was not expected when dumping a stub database connection") +} + +func TestDumpOk(t *testing.T) { + var buf bytes.Buffer + + RunDump(t, &Data{ + Out: &buf, + LockTables: true, + }) + + result := strings.Replace(strings.Split(buf.String(), "-- Dump completed")[0], "`", "~", -1) + + assert.Equal(t, expected, result) +} + +func TestNoLockOk(t *testing.T) { + var buf bytes.Buffer + + data := &Data{ + Out: &buf, + LockTables: false, + } + + db, mock, err := sqlmock.New() + assert.NoError(t, err, "an error was not expected when opening a stub database connection") + defer db.Close() + + data.Connection = db + showTablesRows := sqlmock.NewRows([]string{"Tables_in_Testdb"}). + AddRow("Test_Table") + + serverVersionRows := sqlmock.NewRows([]string{"Version()"}). + AddRow("test_version") + + createTableRows := sqlmock.NewRows([]string{"Table", "Create Table"}). + AddRow("Test_Table", "CREATE TABLE 'Test_Table' (`id` int(11) NOT NULL AUTO_INCREMENT,`email` char(60) DEFAULT NULL, `name` char(60), PRIMARY KEY (`id`))ENGINE=InnoDB DEFAULT CHARSET=latin1") + + createTableValueRows := sqlmock.NewRows([]string{"id", "email", "name"}). + AddRow(1, nil, "Test Name 1"). + AddRow(2, "test2@test.de", "Test Name 2") + + mock.ExpectBegin() + mock.ExpectQuery(`^SELECT version\(\)$`).WillReturnRows(serverVersionRows) + mock.ExpectQuery(`^SHOW TABLES$`).WillReturnRows(showTablesRows) + mock.ExpectQuery("^SHOW CREATE TABLE `Test_Table`$").WillReturnRows(createTableRows) + mock.ExpectQuery("^SELECT (.+) FROM `Test_Table`$").WillReturnRows(createTableValueRows) + mock.ExpectRollback() + + assert.NoError(t, data.Dump(), "an error was not expected when dumping a stub database connection") + + result := strings.Replace(strings.Split(buf.String(), "-- Dump completed")[0], "`", "~", -1) + + assert.Equal(t, expected, result) +} + +func BenchmarkDump(b *testing.B) { + data := &Data{ + Out: ioutil.Discard, + LockTables: true, + } + for i := 0; i < b.N; i++ { + RunDump(b, data) + } +} diff --git a/sanitize.go b/sanitize.go new file mode 100644 index 0000000..1f8fca3 --- /dev/null +++ b/sanitize.go @@ -0,0 +1,27 @@ +package mysqldump + +import "strings" + +var lazyMySQLReplacer *strings.Replacer + +// sanitize MySQL based on +// https://dev.mysql.com/doc/refman/8.0/en/string-literals.html table 9.1 +// needs to be placed in either a single or a double quoted string +func sanitize(input string) string { + if lazyMySQLReplacer == nil { + lazyMySQLReplacer = strings.NewReplacer( + "\x00", "\\0", + "'", "\\'", + "\"", "\\\"", + "\b", "\\b", + "\n", "\\n", + "\r", "\\r", + // "\t", "\\t", Tab literals are acceptable in reads + "\x1A", "\\Z", // ASCII 26 == x1A + "\\", "\\\\", + // "%", "\\%", + // "_", "\\_", + ) + } + return lazyMySQLReplacer.Replace(input) +} diff --git a/sanitize_test.go b/sanitize_test.go new file mode 100644 index 0000000..6303013 --- /dev/null +++ b/sanitize_test.go @@ -0,0 +1,25 @@ +package mysqldump + +import ( + "fmt" + "testing" +) + +func TestForSQLInjection(t *testing.T) { + examples := [][]string{ + /** Query ** Input ** Expected **/ + {"SELECT * WHERE field = '%s';", "test", "SELECT * WHERE field = 'test';"}, + {"'%s'", "'; DROP TABLES `test`;", "'\\'; DROP TABLES `test`;'"}, + {"'%s'", "'+(SELECT name FROM users LIMIT 1)+'", "'\\'+(SELECT name FROM users LIMIT 1)+\\''"}, + {"SELECT '%s'", "\x00x633A5C626F6F742E696E69", "SELECT '\\0x633A5C626F6F742E696E69'"}, + {"WHERE PASSWORD('%s')", "') OR 1=1--", "WHERE PASSWORD('\\') OR 1=1--')"}, + } + var query string + for _, example := range examples { + query = fmt.Sprintf(example[0], sanitize(example[1])) + + if example[2] != query { + t.Fatalf("expected %#v, got %#v", example[2], query) + } + } +}