diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4f02f79 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +coverage.out +tmp.out +profile.out diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 0000000..41b4bcf --- /dev/null +++ b/.travis.yml @@ -0,0 +1,40 @@ +language: go + +matrix: + fast_finish: true + include: + - go: 1.11.x + env: GO111MODULE=on + - go: 1.12.x + env: GO111MODULE=on + - go: 1.13.x + - go: 1.13.x + env: + - TESTTAGS=nomsgpack + - go: 1.14.x + - go: 1.14.x + env: + - TESTTAGS=nomsgpack + - go: master + +git: + depth: 10 + +before_install: + - if [[ "${GO111MODULE}" = "on" ]]; then mkdir "${HOME}/go"; export GOPATH="${HOME}/go"; fi + +install: + - if [[ "${GO111MODULE}" = "on" ]]; then go mod download; fi + - if [[ "${GO111MODULE}" = "on" ]]; then export PATH="${GOPATH}/bin:${GOROOT}/bin:${PATH}"; fi + - if [[ "${GO111MODULE}" = "on" ]]; then make tools; fi + +go_import_path: github.com/larapulse/migrator + +script: + - make vet + - make fmt-check + - make misspell-check + - make test + +after_success: + - bash <(curl -s https://codecov.io/bash) diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..240ff22 --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,25 @@ +MIT License +----------- + +Copyright (c) 2020 Sergey Podgornyy (https://podgornyy.rocks) + +Permission is hereby granted, free of charge, to any person +obtaining a copy of this software and associated documentation +files (the "Software"), to deal in the Software without +restriction, including without limitation the rights to use, +copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT +HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ac727e8 --- /dev/null +++ b/Makefile @@ -0,0 +1,68 @@ +GOFMT ?= gofmt "-s" +PACKAGES ?= $(shell go list ./...) +GOFILES := $(shell find . -name "*.go") +TESTTAGS ?= "" + +.PHONY: test +test: + echo "mode: count" > coverage.out + for d in $(PACKAGES); do \ + go test -timeout 30s -tags $(TESTTAGS) -v -covermode=count -coverprofile=profile.out $$d > tmp.out; \ + cat tmp.out; \ + if grep -q "^--- FAIL" tmp.out; then \ + rm tmp.out; \ + exit 1; \ + elif grep -q "build failed" tmp.out; then \ + rm tmp.out; \ + exit 1; \ + elif grep -q "setup failed" tmp.out; then \ + rm tmp.out; \ + exit 1; \ + fi; \ + if [ -f profile.out ]; then \ + cat profile.out | grep -v "mode:" >> coverage.out; \ + rm profile.out; \ + fi; \ + done + +.PHONY: fmt +fmt: + $(GOFMT) -w $(GOFILES) + +.PHONY: fmt-check +fmt-check: + @diff=$$($(GOFMT) -d $(GOFILES)); \ + if [ -n "$$diff" ]; then \ + echo "Please run 'make fmt' and commit the result:"; \ + echo "$${diff}"; \ + exit 1; \ + fi; + +vet: + go vet $(PACKAGES) + +.PHONY: lint +lint: + @hash golint > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + go get -u golang.org/x/lint/golint; \ + fi + for PKG in $(PACKAGES); do golint -set_exit_status $$PKG || exit 1; done; + +.PHONY: misspell-check +misspell-check: + @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + go get -u github.com/client9/misspell/cmd/misspell; \ + fi + misspell -error $(GOFILES) + +.PHONY: misspell +misspell: + @hash misspell > /dev/null 2>&1; if [ $$? -ne 0 ]; then \ + go get -u github.com/client9/misspell/cmd/misspell; \ + fi + misspell -w $(GOFILES) + +.PHONY: tools +tools: + go install golang.org/x/lint/golint; \ + go install github.com/client9/misspell/cmd/misspell; diff --git a/README.md b/README.md new file mode 100644 index 0000000..3976569 --- /dev/null +++ b/README.md @@ -0,0 +1,217 @@ +# MySQL database migrator + + + +[![Build Status](https://travis-ci.org/larapulse/migrator.svg)](https://travis-ci.org/larapulse/migrator) +[![Software License](https://img.shields.io/badge/license-MIT-brightgreen.svg)](LICENSE.md) +[![codecov](https://codecov.io/gh/larapulse/migrator/branch/master/graph/badge.svg)](https://codecov.io/gh/larapulse/migrator) +[![Go Report Card](https://goreportcard.com/badge/github.com/larapulse/migrator)](https://goreportcard.com/report/github.com/larapulse/migrator) +[![GoDoc](https://godoc.org/github.com/larapulse/migrator?status.svg)](https://pkg.go.dev/github.com/larapulse/migrator?tab=doc) +[![Release](https://img.shields.io/github/release/larapulse/migrator.svg)](https://github.com/larapulse/migrator/releases) +[![TODOs](https://badgen.net/https/api.tickgit.com/badgen/github.com/larapulse/migrator)](https://www.tickgit.com/browse?repo=github.com/larapulse/migrator) + +MySQL database migrator designed to run migrations to your features and manage database schema update with intuitive go code. It is compatible with latest MySQL v8. + +## Installation + +To install `migrator` package, you need to install Go and set your Go workspace first. + +1. The first need [Go](https://golang.org/) installed (**version 1.13+ is required**), then you can use the below Go command to install `migrator`. + +```sh +$ go get -u github.com/larapulse/migrator +``` + +2. Import it in your code: + +```go +import "github.com/larapulse/migrator" +``` + +## Quick start + +Initialize migrator with migration entries: + +```go +var migrations = []migrator.Migration{ + { + Name: "19700101_0001_create_posts_table", + Up: func() migrator.Schema { + var s migrator.Schema + posts := migrator.Table{Name: "posts"} + + posts.UniqueID("id") + posts.Column("title", migrator.String{Precision: 64}) + posts.Column("content", migrator.Text{}) + posts.Timestamps() + + s.CreateTable(posts) + + return s + }, + Down: func() migrator.Schema { + var s migrator.Schema + + s.DropTable("posts") + + return s + }, + }, + { + Name: "19700101_0002_create_comments_table", + Up: func() migrator.Schema { + var s migrator.Schema + comments := migrator.Table{Name: "comments"} + + comments.UniqueID("id") + comments.UUID("post_id", "", false) + comments.Column("title", migrator.String{Precision: 64}) + comments.Column("content", migrator.Text{}) + comments.Timestamps() + + comments.Foreign("post_id", "id", "posts", "RESTRICT", "RESTRICT") + + s.CreateTable(comments) + + return s + }, + Down: func() migrator.Schema { + var s migrator.Schema + + s.DropTable("comments") + + return s + }, + }, +} + +m := migrator.Migrator{Pool: migrations} +migrated, err = m.Migrate(db) + +if err != nil { + fmt.Errorf("Could not migrate: %v", err) + os.Exit(1) +} + +if len(migrated) == 0 { + fmt.Println("Nothing were migrated.") +} + +for _, m := range migrated { + fmt.Println("Migration: "+m+" was migrated ✅") +} + +fmt.Println("Migration did run successfully") +``` + +After first migration run, `migrations` table will be created: + +``` ++----+-------------------------------------+-------+---------------------+ +| id | name | batch | applied_at | ++----+-------------------------------------+-------+---------------------+ +| 1 | 19700101_0001_create_posts_table | 1 | 2020-06-27 00:00:00 | +| 2 | 19700101_0002_create_comments_table | 1 | 2020-06-27 00:00:00 | ++----+-------------------------------------+-------+---------------------+ +``` + +If you want to use another name for migration table, change it `Migrator` before running migrations: + +```go +m := migrator.Migrator{TableName: "_my_app_migrations"} +``` + +### Transactional migration + +In case you have multiple commands within one migration and you want to be sure it is migrated properly, you might enable transactional execution per migration: + +```go +var migration = migrator.Migration{ + Name: "19700101_0001_create_posts_and_users_tables", + Up: func() migrator.Schema { + var s migrator.Schema + posts := migrator.Table{Name: "posts"} + posts.UniqueID("id") + posts.Timestamps() + + users := migrator.Table{Name: "users"} + users.UniqueID("id") + users.Timestamps() + + s.CreateTable(posts) + s.CreateTable(users) + + return s + }, + Down: func() migrator.Schema { + var s migrator.Schema + + s.DropTable("users") + s.DropTable("posts") + + return s + }, + Transaction: true, +} +``` + +### Rollback and revert + +In case you need to revert your deploy and DB, you can revert last migrated batch: + +```go +m := migrator.Migrator{Pool: migrations} +reverted, err := m.Rollback(db) + +if err != nil { + fmt.Errorf("Could not roll back migrations: %v", err) + os.Exit(1) +} + +if len(reverted) == 0 { + fmt.Println("Nothing were rolled back.") +} + +for _, m := range reverted { + fmt.Println("Migration: "+m+" was rolled back ✅") +} +``` + +To revert all migrated items back, you have to call `Revert()` on your `migrator`: + +```go +m := migrator.Migrator{Pool: migrations} +reverted, err := m.Revert(db) +``` + +## Customize queries + +You may add any column definition to the database on your own, just be sure you implement `columnType` interface: + +```go +type customType string + +func (ct customType) buildRow() string { + return string(ct) +} + +posts := migrator.Table{Name: "posts"} +posts.UniqueID("id") +posts.Column("data", customType("json not null")) +posts.Timestamps() +``` + +Same logic is for adding custom commands to the Schema to be migrated or reverted, just be sure you implement `command` interface: + +```go +type customCommand string + +func (cc customCommand) toSQL() string { + return string(cc) +} + +var s migrator.Schema + +c := customCommand("DROP PROCEDURE abc") +s.CustomCommand(c) +``` diff --git a/column.go b/column.go new file mode 100644 index 0000000..c709d4e --- /dev/null +++ b/column.go @@ -0,0 +1,533 @@ +// Package migrator represents MySQL database migrator +package migrator + +import ( + "fmt" + "strconv" + "strings" +) + +type columns []column + +func (c columns) render() string { + rows := []string{} + + for _, item := range c { + rows = append(rows, fmt.Sprintf("`%s` %s", item.field, item.definition.buildRow())) + } + + return strings.Join(rows, ", ") +} + +type column struct { + field string + definition columnType +} + +type columnType interface { + buildRow() string +} + +// Integer represents integer value in DB: {tiny,small,medium,big}int +// +// Default migrator.Integer will build a sql row: `int NOT NULL` +// Examples: +// tinyint ➡️ migrator.Integer{Prefix: "tiny", Unsigned: true, Precision: 1, Default: "0"} +// ↪️ tinyint(1) unsigned NOT NULL DEFAULT 0 +// int ➡️ migrator.Integer{Nullable: true, OnUpdate: "set null", Comment: "nullable counter"} +// ↪️ int NULL ON UPDATE set null COMMENT 'nullable counter' +// mediumint ➡️ migrator.Integer{Prefix: "medium", Precision: "255"} +// ↪️ mediumint(255) NOT NULL +// bigint ➡️ migrator.Integer{Prefix: "big", Unsigned: true, Precision: "255", Autoincrement: true} +// ↪️ bigint(255) unsigned NOT NULL AUTO_INCREMENT +type Integer struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Prefix string // tiny, small, medium, big + Unsigned bool + Precision uint16 + Autoincrement bool +} + +func (i Integer) buildRow() string { + sql := i.Prefix + "int" + if i.Precision > 0 { + sql += fmt.Sprintf("(%s)", strconv.Itoa(int(i.Precision))) + } + + if i.Unsigned { + sql += " unsigned" + } + + if i.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if i.Default != "" { + sql += " DEFAULT " + i.Default + } + + if i.Autoincrement { + sql += " AUTO_INCREMENT" + } + + if i.OnUpdate != "" { + sql += " ON UPDATE " + i.OnUpdate + } + + if i.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", i.Comment) + } + + return sql +} + +// Floatable replresents number with floating point in DB: +// `float`, `double` or `decimal` +// +// Default migrator.Floatable will build a sql row: `float NOT NULL` +// Examples: +// float ➡️ migrator.Floatable{Precision: 2, Nullable: true} +// ↪️ float(2) NULL +// real ➡️ migrator.Floatable{Type: "real", Precision: 5, Scale: 2} +// ↪️ real(5,2) NOT NULL +// double ➡️ migrator.Floatable{Type: "double", Scale: 2, Unsigned: true} +// ↪️ double(0,2) unsigned NOT NULL +// decimal ➡️ migrator.Floatable{Type: "decimal", Precision: 15, Scale: 2, OnUpdate: "0.0", Comment: "money"} +// ↪️ decimal(15,2) NOT NULL ON UPDATE 0.0 COMMENT 'money' +// numeric ➡️ migrator.Floatable{Type: "numeric", Default: "0.0"} +// ↪️ numeric NOT NULL DEFAULT 0.0 +type Floatable struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Type string // float, real, double, decimal, numeric + Unsigned bool + Precision uint16 + Scale uint16 +} + +func (f Floatable) buildRow() string { + sql := f.Type + + if sql == "" { + sql = "float" + } + + if f.Scale > 0 { + sql += fmt.Sprintf("(%s,%s)", strconv.Itoa(int(f.Precision)), strconv.Itoa(int(f.Scale))) + } else if f.Precision > 0 { + sql += fmt.Sprintf("(%s)", strconv.Itoa(int(f.Precision))) + } + + if f.Unsigned { + sql += " unsigned" + } + + if f.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if f.Default != "" { + sql += " DEFAULT " + f.Default + } + + if f.OnUpdate != "" { + sql += " ON UPDATE " + f.OnUpdate + } + + if f.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", f.Comment) + } + + return sql +} + +// Timable represents DB representation of timable column type: +// `date`, `datetime`, `timestamp`, `time` or `year` +// +// Default migrator.Timable will build a sql row: `timestamp NOT NULL` +// Examples: +// date ➡️ migrator.Timable{Type: "date", Nullable: true} +// ↪️ date NULL +// datetime ➡️ migrator.Timable{Type: "datetime", Default: "CURRENT_TIMESTAMP"} +// ↪️ datetime NOT NULL DEFAULT CURRENT_TIMESTAMP +// timestamp ➡️ migrator.Timable{Default: "CURRENT_TIMESTAMP", OnUpdate: "CURRENT_TIMESTAMP"} +// ↪️ timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +// time ➡️ migrator.Timable{Type: "time", Comment: "meeting time"} +// ↪️ time NOT NULL COMMENT 'meeting time' +// year ➡️ migrator.Timable{Type: "year", Nullable: true} +// ↪️ year NULL +type Timable struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Type string // date, time, datetime, timestamp, year +} + +func (t Timable) buildRow() string { + sql := t.Type + + if sql == "" { + sql = "timestamp" + } + + if t.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if t.Default != "" { + sql += " DEFAULT " + t.Default + } + + if t.OnUpdate != "" { + sql += " ON UPDATE " + t.OnUpdate + } + + if t.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", t.Comment) + } + + return sql +} + +// String represents basic DB string column type: `char` or `varchar` +// +// Default migrator.String will build a sql row: `varchar COLLATE utf8mb4_unicode_ci NOT NULL` +// Examples: +// char ➡️ migrator.String{Fixed: true, Precision: 36, Nullable: true, OnUpdate: "set null", Comment: "uuid"} +// ↪️ char(36) COLLATE utf8mb4_unicode_ci NULL ON UPDATE set null COMMENT 'uuid' +// varchar ➡️ migrator.String{Precision: 255, Default: "active", Charset: "utf8mb4", Collate: "utf8mb4_general_ci"} +// ↪️ varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT 'active' +type String struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Charset string + Collate string + + Fixed bool // char for fixed, otherwise varchar + Precision uint16 +} + +func (s String) buildRow() string { + sql := "" + + if !s.Fixed { + sql += "var" + } + + sql += "char" + + if s.Precision > 0 { + sql += fmt.Sprintf("(%s)", strconv.Itoa(int(s.Precision))) + } + + if s.Charset != "" { + sql += " CHARACTER SET " + s.Charset + } + + if s.Collate != "" { + sql += " COLLATE " + s.Collate + } else if s.Charset == "" { + // use default + sql += " COLLATE utf8mb4_unicode_ci" + } + + if s.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if s.Default != "" { + sql += fmt.Sprintf(" DEFAULT '%s'", s.Default) + } + + if s.OnUpdate != "" { + sql += " ON UPDATE " + s.OnUpdate + } + + if s.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", s.Comment) + } + + return sql +} + +// Text represents long text column type represented in DB as: +// - {tiny,medium,long}text +// - {tiny,medium,long}blob +// +// Default migrator.Text will build a sql row: `text COLLATE utf8mb4_unicode_ci NOT NULL` +// Examples: +// tinytext ➡️ migrator.Text{Prefix: "tiny"} +// ↪️ tinytext COLLATE utf8mb4_unicode_ci NOT NULL +// text ➡️ migrator.Text{Nullable: true, OnUpdate: "set null", Comment: "write your comment here"} +// ↪️ text COLLATE utf8mb4_unicode_ci NULL ON UPDATE set null COMMENT 'write your comment here' +// mediumtext ➡️ migrator.Text{Prefix: "medium"} +// ↪️ mediumtext COLLATE utf8mb4_unicode_ci NOT NULL +// longtext ➡️ migrator.Text{Prefix: "long", Default: "write you text", Charset: "utf8mb4", Collate: "utf8mb4_general_ci"} +// ↪️ longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NOT NULL DEFAULT 'write you text' +// tinyblob ➡️ migrator.Text{Prefix: "tiny", Blob: true} +// ↪️ tinyblob COLLATE utf8mb4_unicode_ci NOT NULL +// blob ➡️ migrator.Text{Blob: true} +// ↪️ blob COLLATE utf8mb4_unicode_ci NOT NULL +// mediumblob ➡️ migrator.Text{Prefix: "medium", Blob: true} +// ↪️ mediumblob COLLATE utf8mb4_unicode_ci NOT NULL +// longblob ➡️ migrator.Text{Prefix: "long", Blob: true} +// ↪️ longblob COLLATE utf8mb4_unicode_ci NOT NULL +type Text struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Charset string + Collate string + + Prefix string // tiny, medium, long + Blob bool // for binary +} + +func (t Text) buildRow() string { + sql := t.Prefix + + if t.Blob { + sql += "blob" + } else { + sql += "text" + } + + if t.Charset != "" { + sql += " CHARACTER SET " + t.Charset + } + + if t.Collate != "" { + sql += " COLLATE " + t.Collate + } else if t.Charset == "" { + // use default + sql += " COLLATE utf8mb4_unicode_ci" + } + + if t.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if t.Default != "" { + sql += fmt.Sprintf(" DEFAULT '%s'", t.Default) + } + + if t.OnUpdate != "" { + sql += " ON UPDATE " + t.OnUpdate + } + + if t.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", t.Comment) + } + + return sql +} + +// JSON represents DB column type `json` +// +// Default migrator.JSON will build a sql row: `json NOT NULL` +// Examples: +// ➡️ migrator.JSON{Nullable: true, Comment: "user data"} +// ↪️ json NULL COMMENT 'user data' +// ➡️ migrator.JSON{Default: "{}", OnUpdate: "{}"} +// ↪️ json NOT NULL DEFAULT '{}' ON UPDATE {} +type JSON struct { + Default string + Nullable bool + Comment string + OnUpdate string +} + +func (j JSON) buildRow() string { + sql := "json" + + if j.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if j.Default != "" { + sql += fmt.Sprintf(" DEFAULT '%s'", j.Default) + } + + if j.OnUpdate != "" { + sql += " ON UPDATE " + j.OnUpdate + } + + if j.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", j.Comment) + } + + return sql +} + +// Enum represents choisable value. In database represented by: `enum` or `set` +// +// Default migrator.Enum will build a sql row: `enum('') NOT NULL` +// Examples: +// enum ➡️ migrator.Enum{Values: []string{"on", "off"}, Default: "off", Nullable: true, OnUpdate: "set null"} +// ↪️ enum('on', 'off') NULL DEFAULT 'off' ON UPDATE set null +// set ➡️ migrator.Enum{Values: []string{"1", "2", "3"}, Comment: "options"} +// ↪️ set('1', '2', '3') NOT NULL COMMENT 'options' +type Enum struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Values []string + Multiple bool // "set", otherwise "enum" +} + +func (e Enum) buildRow() string { + sql := "" + + if e.Multiple { + sql += "set" + } else { + sql += "enum" + } + + sql += "('" + strings.Join(e.Values, "', '") + "')" + + if e.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if e.Default != "" { + sql += fmt.Sprintf(" DEFAULT '%s'", e.Default) + } + + if e.OnUpdate != "" { + sql += " ON UPDATE " + e.OnUpdate + } + + if e.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", e.Comment) + } + + return sql +} + +// Bit represents default `bit` column type +// +// Default migrator.Bit will build a sql row: `bit NOT NULL` +// Examples: +// ➡️ migrator.Bit{Precision: 8, Default: "1", Comment: "mario game code"} +// ↪️ bit(8) NOT NULL DEFAULT 1 COMMENT 'mario game code' +// ➡️ migrator.Bit{Precision: 64, Nullable: true, OnUpdate: "set null"} +// ↪️ bit(64) NULL ON UPDATE set null +type Bit struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Precision uint16 +} + +func (b Bit) buildRow() string { + sql := "bit" + + if b.Precision > 0 { + sql += "(" + strconv.Itoa(int(b.Precision)) + ")" + } + + if b.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if b.Default != "" { + sql += " DEFAULT " + b.Default + } + + if b.OnUpdate != "" { + sql += " ON UPDATE " + b.OnUpdate + } + + if b.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", b.Comment) + } + + return sql +} + +// Binary represents binary column type: `binary` or `varbinary` +// +// Default migrator.Binary will build a sql row: `varbinary NOT NULL` +// Examples: +// binary ➡️ migrator.Binary{Fixed: true, Precision: 36, Default: "1", Comment: "uuid"} +// ↪️ binary(36) NOT NULL DEFAULT 1 COMMENT 'uuid' +// varbinary ➡️ migrator.Binary{Precision: 255, Nullable: true, OnUpdate: "set null"} +// ↪️ varbinary(255) NULL ON UPDATE set null +type Binary struct { + Default string + Nullable bool + Comment string + OnUpdate string + + Fixed bool // binary for fixed, otherwise varbinary + Precision uint16 +} + +func (b Binary) buildRow() string { + sql := "" + + if !b.Fixed { + sql += "var" + } + + sql += "binary" + + if b.Precision > 0 { + sql += fmt.Sprintf("(%s)", strconv.Itoa(int(b.Precision))) + } + + if b.Nullable { + sql += " NULL" + } else { + sql += " NOT NULL" + } + + if b.Default != "" { + sql += " DEFAULT " + b.Default + } + + if b.OnUpdate != "" { + sql += " ON UPDATE " + b.OnUpdate + } + + if b.Comment != "" { + sql += fmt.Sprintf(" COMMENT '%s'", b.Comment) + } + + return sql +} diff --git a/column_test.go b/column_test.go new file mode 100644 index 0000000..ac556e2 --- /dev/null +++ b/column_test.go @@ -0,0 +1,545 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type testColumnType string + +func (c testColumnType) buildRow() string { + return string(c) +} + +func TestColumnRender(t *testing.T) { + t.Run("it renders row from one column", func(t *testing.T) { + c := columns{column{"test", testColumnType("run")}} + + assert.Equal(t, "`test` run", c.render()) + }) + + t.Run("it renders row from multiple columns", func(t *testing.T) { + c := columns{ + column{"test", testColumnType("run")}, + column{"again", testColumnType("me")}, + } + + assert.Equal(t, "`test` run, `again` me", c.render()) + }) +} + +func TestInteger(t *testing.T) { + t.Run("it builds basic column type", func(t *testing.T) { + c := Integer{} + assert.Equal(t, "int NOT NULL", c.buildRow()) + }) + + t.Run("it build with prefix", func(t *testing.T) { + c := Integer{Prefix: "super"} + assert.Equal(t, "superint NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision", func(t *testing.T) { + c := Integer{Precision: 20} + assert.Equal(t, "int(20) NOT NULL", c.buildRow()) + }) + + t.Run("it builds unsigned", func(t *testing.T) { + c := Integer{Unsigned: true} + assert.Equal(t, "int unsigned NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Integer{Nullable: true} + assert.Equal(t, "int NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Integer{Default: "0"} + assert.Equal(t, "int NOT NULL DEFAULT 0", c.buildRow()) + }) + + t.Run("it builds with autoincrement", func(t *testing.T) { + c := Integer{Autoincrement: true} + assert.Equal(t, "int NOT NULL AUTO_INCREMENT", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Integer{OnUpdate: "set null"} + assert.Equal(t, "int NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Integer{Comment: "test"} + assert.Equal(t, "int NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Integer{ + Prefix: "big", + Precision: 10, + Unsigned: true, + Nullable: true, + Default: "100", + Autoincrement: true, + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "bigint(10) unsigned NULL DEFAULT 100 AUTO_INCREMENT ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestFloatable(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := Floatable{} + assert.Equal(t, "float NOT NULL", c.buildRow()) + }) + + t.Run("it builds basic column type", func(t *testing.T) { + c := Floatable{Type: "real"} + assert.Equal(t, "real NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision", func(t *testing.T) { + c := Floatable{Type: "double", Precision: 20} + assert.Equal(t, "double(20) NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision and scale", func(t *testing.T) { + c := Floatable{Type: "decimal", Precision: 10, Scale: 2} + assert.Equal(t, "decimal(10,2) NOT NULL", c.buildRow()) + }) + + t.Run("it builds unsigned", func(t *testing.T) { + c := Floatable{Unsigned: true} + assert.Equal(t, "float unsigned NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Floatable{Nullable: true} + assert.Equal(t, "float NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Floatable{Default: "0.0"} + assert.Equal(t, "float NOT NULL DEFAULT 0.0", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Floatable{OnUpdate: "set null"} + assert.Equal(t, "float NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Floatable{Comment: "test"} + assert.Equal(t, "float NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Floatable{ + Type: "decimal", + Precision: 10, + Scale: 2, + Unsigned: true, + Nullable: true, + Default: "100.0", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "decimal(10,2) unsigned NULL DEFAULT 100.0 ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestTimeable(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := Timable{} + assert.Equal(t, "timestamp NOT NULL", c.buildRow()) + }) + + t.Run("it builds basic column type", func(t *testing.T) { + c := Timable{Type: "datetime"} + assert.Equal(t, "datetime NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Timable{Nullable: true} + assert.Equal(t, "timestamp NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Timable{Default: "CURRENT_TIMESTAMP"} + assert.Equal(t, "timestamp NOT NULL DEFAULT CURRENT_TIMESTAMP", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Timable{OnUpdate: "set null"} + assert.Equal(t, "timestamp NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Timable{Comment: "test"} + assert.Equal(t, "timestamp NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Timable{ + Type: "datetime", + Nullable: true, + Default: "CURRENT_TIMESTAMP", + OnUpdate: "CURRENT_TIMESTAMP", + Comment: "test", + } + + assert.Equal( + t, + "datetime NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestString(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := String{} + assert.Equal(t, "varchar COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds fixed", func(t *testing.T) { + c := String{Fixed: true} + assert.Equal(t, "char COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision", func(t *testing.T) { + c := String{Precision: 255} + assert.Equal(t, "varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds with charset", func(t *testing.T) { + c := String{Charset: "utf8"} + assert.Equal(t, "varchar CHARACTER SET utf8 NOT NULL", c.buildRow()) + }) + + t.Run("it builds with collate", func(t *testing.T) { + c := String{Collate: "utf8mb4_general_ci"} + assert.Equal(t, "varchar COLLATE utf8mb4_general_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := String{Nullable: true} + assert.Equal(t, "varchar COLLATE utf8mb4_unicode_ci NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := String{Default: "done"} + assert.Equal(t, "varchar COLLATE utf8mb4_unicode_ci NOT NULL DEFAULT 'done'", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := String{OnUpdate: "set null"} + assert.Equal(t, "varchar COLLATE utf8mb4_unicode_ci NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := String{Comment: "test"} + assert.Equal(t, "varchar COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := String{ + Fixed: true, + Precision: 36, + Nullable: true, + Charset: "utf8mb4", + Collate: "utf8mb4_general_ci", + Default: "nice", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "char(36) CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT 'nice' ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestText(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := Text{} + assert.Equal(t, "text COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds with prefix", func(t *testing.T) { + c := Text{Prefix: "medium"} + assert.Equal(t, "mediumtext COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds blob", func(t *testing.T) { + c := Text{Blob: true} + assert.Equal(t, "blob COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds blob with prefix", func(t *testing.T) { + c := Text{Prefix: "tiny", Blob: true} + assert.Equal(t, "tinyblob COLLATE utf8mb4_unicode_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds with charset", func(t *testing.T) { + c := Text{Charset: "utf8"} + assert.Equal(t, "text CHARACTER SET utf8 NOT NULL", c.buildRow()) + }) + + t.Run("it builds with collate", func(t *testing.T) { + c := Text{Collate: "utf8mb4_general_ci"} + assert.Equal(t, "text COLLATE utf8mb4_general_ci NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Text{Nullable: true} + assert.Equal(t, "text COLLATE utf8mb4_unicode_ci NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Text{Default: "done"} + assert.Equal(t, "text COLLATE utf8mb4_unicode_ci NOT NULL DEFAULT 'done'", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Text{OnUpdate: "set null"} + assert.Equal(t, "text COLLATE utf8mb4_unicode_ci NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Text{Comment: "test"} + assert.Equal(t, "text COLLATE utf8mb4_unicode_ci NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Text{ + Prefix: "long", + Blob: true, + Nullable: true, + Charset: "utf8mb4", + Collate: "utf8mb4_general_ci", + Default: "nice", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "longblob CHARACTER SET utf8mb4 COLLATE utf8mb4_general_ci NULL DEFAULT 'nice' ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestJson(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := JSON{} + assert.Equal(t, "json NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := JSON{Nullable: true} + assert.Equal(t, "json NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := JSON{Default: "{}"} + assert.Equal(t, "json NOT NULL DEFAULT '{}'", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := JSON{OnUpdate: "set null"} + assert.Equal(t, "json NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := JSON{Comment: "test"} + assert.Equal(t, "json NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := JSON{ + Nullable: true, + Default: "{}", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "json NULL DEFAULT '{}' ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestEnum(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := Enum{} + assert.Equal(t, "enum('') NOT NULL", c.buildRow()) + }) + + t.Run("it builds with multiple flag", func(t *testing.T) { + c := Enum{Multiple: true} + assert.Equal(t, "set('') NOT NULL", c.buildRow()) + }) + + t.Run("it builds with values", func(t *testing.T) { + c := Enum{Values: []string{"active", "inactive"}} + assert.Equal(t, "enum('active', 'inactive') NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Enum{Nullable: true} + assert.Equal(t, "enum('') NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Enum{Default: "valid"} + assert.Equal(t, "enum('') NOT NULL DEFAULT 'valid'", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Enum{OnUpdate: "set null"} + assert.Equal(t, "enum('') NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Enum{Comment: "test"} + assert.Equal(t, "enum('') NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Enum{ + Multiple: true, + Values: []string{"male", "female", "other"}, + Nullable: true, + Default: "male,female", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "set('male', 'female', 'other') NULL DEFAULT 'male,female' ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestBit(t *testing.T) { + t.Run("it builds basic column type", func(t *testing.T) { + c := Bit{} + assert.Equal(t, "bit NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision", func(t *testing.T) { + c := Bit{Precision: 20} + assert.Equal(t, "bit(20) NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Bit{Nullable: true} + assert.Equal(t, "bit NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Bit{Default: "1"} + assert.Equal(t, "bit NOT NULL DEFAULT 1", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Bit{OnUpdate: "set null"} + assert.Equal(t, "bit NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Bit{Comment: "test"} + assert.Equal(t, "bit NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Bit{ + Precision: 10, + Nullable: true, + Default: "0", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "bit(10) NULL DEFAULT 0 ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} + +func TestBinary(t *testing.T) { + t.Run("it builds with default type", func(t *testing.T) { + c := Binary{} + assert.Equal(t, "varbinary NOT NULL", c.buildRow()) + }) + + t.Run("it builds fixed", func(t *testing.T) { + c := Binary{Fixed: true} + assert.Equal(t, "binary NOT NULL", c.buildRow()) + }) + + t.Run("it builds with precision", func(t *testing.T) { + c := Binary{Precision: 255} + assert.Equal(t, "varbinary(255) NOT NULL", c.buildRow()) + }) + + t.Run("it builds nullable column type", func(t *testing.T) { + c := Binary{Nullable: true} + assert.Equal(t, "varbinary NULL", c.buildRow()) + }) + + t.Run("it builds with default value", func(t *testing.T) { + c := Binary{Default: "1"} + assert.Equal(t, "varbinary NOT NULL DEFAULT 1", c.buildRow()) + }) + + t.Run("it builds with on_update setter", func(t *testing.T) { + c := Binary{OnUpdate: "set null"} + assert.Equal(t, "varbinary NOT NULL ON UPDATE set null", c.buildRow()) + }) + + t.Run("it builds with comment", func(t *testing.T) { + c := Binary{Comment: "test"} + assert.Equal(t, "varbinary NOT NULL COMMENT 'test'", c.buildRow()) + }) + + t.Run("it builds string with all parameters", func(t *testing.T) { + c := Binary{ + Fixed: true, + Precision: 36, + Nullable: true, + Default: "1", + OnUpdate: "set null", + Comment: "test", + } + + assert.Equal( + t, + "binary(36) NULL DEFAULT 1 ON UPDATE set null COMMENT 'test'", + c.buildRow(), + ) + }) +} diff --git a/foreign.go b/foreign.go new file mode 100644 index 0000000..d24c94e --- /dev/null +++ b/foreign.go @@ -0,0 +1,58 @@ +// Package migrator represents MySQL database migrator +package migrator + +import ( + "fmt" + "strings" +) + +type foreigns []foreign + +func (f foreigns) render() string { + values := []string{} + + for _, foreign := range f { + values = append(values, foreign.render()) + } + + return strings.Join(values, ", ") +} + +type foreign struct { + key string + column string + reference string // reference field + on string // reference table + onUpdate string + onDelete string +} + +func (f foreign) render() string { + if f.key == "" || f.column == "" || f.on == "" || f.reference == "" { + return "" + } + + sql := fmt.Sprintf("CONSTRAINT `%s` FOREIGN KEY (`%s`) REFERENCES `%s` (`%s`)", f.key, f.column, f.on, f.reference) + if referenceOptions.has(strings.ToUpper(f.onDelete)) { + sql += " ON DELETE " + strings.ToUpper(f.onDelete) + } + if referenceOptions.has(strings.ToUpper(f.onUpdate)) { + sql += " ON UPDATE " + strings.ToUpper(f.onUpdate) + } + + return sql +} + +var referenceOptions = list{"SET NULL", "CASCADE", "RESTRICT", "NO ACTION", "SET DEFAULT"} + +type list []string + +func (l list) has(value string) bool { + for _, item := range l { + if item == value { + return true + } + } + + return false +} diff --git a/foreign_test.go b/foreign_test.go new file mode 100644 index 0000000..09836ed --- /dev/null +++ b/foreign_test.go @@ -0,0 +1,72 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestForeigns(t *testing.T) { + t.Run("it returns empty on empty keys", func(t *testing.T) { + f := foreigns{foreign{}} + + assert.Equal(t, "", f.render()) + }) + + t.Run("it renders row from one foreign", func(t *testing.T) { + f := foreigns{foreign{key: "idx_foreign", column: "test_id", reference: "id", on: "tests"}} + + assert.Equal(t, "CONSTRAINT `idx_foreign` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`)", f.render()) + }) + + t.Run("it renders row from multiple foreigns", func(t *testing.T) { + f := foreigns{ + foreign{key: "idx_foreign", column: "test_id", reference: "id", on: "tests"}, + foreign{key: "foreign_idx", column: "random_id", reference: "id", on: "randoms"}, + } + + assert.Equal( + t, + "CONSTRAINT `idx_foreign` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`), CONSTRAINT `foreign_idx` FOREIGN KEY (`random_id`) REFERENCES `randoms` (`id`)", + f.render(), + ) + }) +} + +func TestForeign(t *testing.T) { + t.Run("it builds base constraint", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`)", f.render()) + }) + + t.Run("it builds contraint with on_update", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests", onUpdate: "no action"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`) ON UPDATE NO ACTION", f.render()) + }) + + t.Run("it builds contraint without invalid on_update", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests", onUpdate: "null"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`)", f.render()) + }) + + t.Run("it builds contraint with on_update", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests", onDelete: "set default"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`) ON DELETE SET DEFAULT", f.render()) + }) + + t.Run("it builds contraint without invalid on_update", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests", onDelete: "default"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`)", f.render()) + }) + + t.Run("it builds full contraint", func(t *testing.T) { + f := foreign{key: "foreign_idx", column: "test_id", reference: "id", on: "tests", onUpdate: "cascade", onDelete: "restrict"} + + assert.Equal(t, "CONSTRAINT `foreign_idx` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`) ON DELETE RESTRICT ON UPDATE CASCADE", f.render()) + }) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..92b57c5 --- /dev/null +++ b/go.mod @@ -0,0 +1,8 @@ +module github.com/larapulse/migrator + +go 1.13 + +require ( + github.com/DATA-DOG/go-sqlmock v1.4.1 + github.com/stretchr/testify v1.6.1 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f5ac1e8 --- /dev/null +++ b/go.sum @@ -0,0 +1,14 @@ +github.com/DATA-DOG/go-sqlmock v1.4.1 h1:ThlnYciV1iM/V0OSF/dtkqWb6xo5qITT1TJBG1MRDJM= +github.com/DATA-DOG/go-sqlmock v1.4.1/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/key.go b/key.go new file mode 100644 index 0000000..ad49a09 --- /dev/null +++ b/key.go @@ -0,0 +1,48 @@ +// Package migrator represents MySQL database migrator +package migrator + +import "strings" + +type keys []key + +func (k keys) render() string { + values := []string{} + + for _, key := range k { + value := key.render() + if value != "" { + values = append(values, value) + } + } + + return strings.Join(values, ", ") +} + +type key struct { + name string + typ string // primary, unique + columns []string +} + +var keyTypes = list{"PRIMARY", "UNIQUE"} + +func (k key) render() string { + if len(k.columns) == 0 { + return "" + } + + sql := "" + if keyTypes.has(strings.ToUpper(k.typ)) { + sql += strings.ToUpper(k.typ) + " " + } + + sql += "KEY" + + if k.name != "" { + sql += " `" + k.name + "`" + } + + sql += " (`" + strings.Join(k.columns, "`, `") + "`)" + + return sql +} diff --git a/key_test.go b/key_test.go new file mode 100644 index 0000000..1cdc950 --- /dev/null +++ b/key_test.go @@ -0,0 +1,66 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeys(t *testing.T) { + t.Run("it returns empty on empty keys", func(t *testing.T) { + k := keys{key{}} + + assert.Equal(t, "", k.render()) + }) + + t.Run("it renders row from one key", func(t *testing.T) { + k := keys{key{columns: []string{"test_id"}}} + + assert.Equal(t, "KEY (`test_id`)", k.render()) + }) + + t.Run("it renders row from multiple keys", func(t *testing.T) { + k := keys{ + key{columns: []string{"test_id"}}, + key{columns: []string{"random_id"}}, + } + + assert.Equal( + t, + "KEY (`test_id`), KEY (`random_id`)", + k.render(), + ) + }) +} + +func TestKey(t *testing.T) { + t.Run("it returns empty on empty keys", func(t *testing.T) { + k := key{} + + assert.Equal(t, "", k.render()) + }) + + t.Run("it skips type if it is not in valid list", func(t *testing.T) { + k := key{typ: "random", columns: []string{"test_id"}} + + assert.Equal(t, "KEY (`test_id`)", k.render()) + }) + + t.Run("it renders with type", func(t *testing.T) { + k := key{typ: "primary", columns: []string{"test_id"}} + + assert.Equal(t, "PRIMARY KEY (`test_id`)", k.render()) + }) + + t.Run("it renders with multiple columns", func(t *testing.T) { + k := key{typ: "unique", columns: []string{"test_id", "random_id"}} + + assert.Equal(t, "UNIQUE KEY (`test_id`, `random_id`)", k.render()) + }) + + t.Run("it renders with name", func(t *testing.T) { + k := key{name: "random_idx", columns: []string{"test_id"}} + + assert.Equal(t, "KEY `random_idx` (`test_id`)", k.render()) + }) +} diff --git a/logo.png b/logo.png new file mode 100644 index 0000000..1b62af7 Binary files /dev/null and b/logo.png differ diff --git a/migration.go b/migration.go new file mode 100644 index 0000000..d26148d --- /dev/null +++ b/migration.go @@ -0,0 +1,88 @@ +// Package migrator represents MySQL database migrator +package migrator + +import "database/sql" + +type executableSQL interface { + Exec(query string, args ...interface{}) (sql.Result, error) +} + +// Migration represents migration entity +// +// Name should be unique name to specify migration. It is up to you to choose the name you like +// Up() should return Schema with prepared commands to be migrated +// Down() should return Schema with prepared commands to be reverted +// Transaction optinal flag to enable transaction for migration +// +// Example: +// var migration = migrator.Migration{ +// Name: "19700101_0001_create_posts_table", +// Up: func() migrator.Schema { +// var s migrator.Schema +// posts := migrator.Table{Name: "posts"} +// +// posts.UniqueID("id") +// posts.Column("title", migrator.String{Precision: 64}) +// posts.Column("content", migrator.Text{}) +// posts.Timestamps() +// +// s.CreateTable(posts) +// +// return s +// }, +// Down: func() migrator.Schema { +// var s migrator.Schema +// +// s.DropTable("posts") +// +// return s +// }, +// } +type Migration struct { + Name string + Up func() Schema + Down func() Schema + Transaction bool +} + +func (m Migration) exec(db *sql.DB, commands ...command) error { + if m.Transaction { + return runInTransaction(db, commands...) + } + + return run(db, commands...) +} + +func runInTransaction(db *sql.DB, commands ...command) error { + tx, err := db.Begin() + if err != nil { + return err + } + + err = run(tx, commands...) + if err != nil { + tx.Rollback() + return err + } + + err = tx.Commit() + if err != nil { + return err + } + + return nil +} + +func run(db executableSQL, commands ...command) error { + for _, command := range commands { + sql := command.toSQL() + if sql == "" { + return ErrNoSQLCommandsToRun + } + if _, err := db.Exec(sql); err != nil { + return err + } + } + + return nil +} diff --git a/migration_test.go b/migration_test.go new file mode 100644 index 0000000..33c7fac --- /dev/null +++ b/migration_test.go @@ -0,0 +1,204 @@ +package migrator + +import ( + "database/sql" + "errors" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +var ( + errTestDBExecFailed = errors.New("DB exec command failed") + errTestDBQueryFailed = errors.New("DB query command failed") + errTestDBTransactionFailed = errors.New("DB transaction failed") + errTestLastInsertID = errors.New("Failed to get last insert ID") + errTestAffectedRows = errors.New("Failed to amount of affected rows") +) + +type testDummyCommand string + +func (c testDummyCommand) toSQL() string { + return string(c) +} + +func testDBConnection(t *testing.T) (db *sql.DB, mock sqlmock.Sqlmock, resetDB func()) { + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) + } + + resetDB = func() { + defer db.Close() + + // we make sure that all expectations were met + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("there were unfulfilled expectations: %s", err) + } + } + + return +} + +func TestMigrationExec(t *testing.T) { + t.Run("it executes migration in transaction", func(t *testing.T) { + m := Migration{Transaction: true} + + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand("test"), + } + + mock.ExpectBegin() + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(commands[1].toSQL()).WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + + // now we execute our method + if err := m.exec(db, commands...); err != nil { + t.Errorf("error was not expected while running query: %s", err) + } + }) + + t.Run("it executes general transaction", func(t *testing.T) { + m := Migration{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand("test"), + } + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(commands[1].toSQL()).WillReturnResult(sqlmock.NewResult(2, 1)) + + // now we execute our method + if err := m.exec(db, commands...); err != nil { + t.Errorf("error was not expected while running query: %s", err) + } + }) +} + +func TestRunInTransaction(t *testing.T) { + t.Run("it returns an error if transaction wasn't started", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{} + want := sqlmock.ErrCancelled + mock.ExpectBegin().WillReturnError(want) + + // now we execute our method + got := runInTransaction(db, commands...) + assert.Equal(t, want, got) + }) + + t.Run("it rolled back transaction in case of error", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{testDummyCommand("run")} + want := sqlmock.ErrCancelled + + mock.ExpectBegin() + mock.ExpectExec(commands[0].toSQL()).WillReturnError(want) + mock.ExpectRollback() + + // now we execute our method + got := runInTransaction(db, commands...) + assert.Equal(t, want, got) + }) + + t.Run("it returns an error if committing transaction was unsuccessful", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{} + want := sqlmock.ErrCancelled + + mock.ExpectBegin() + mock.ExpectCommit().WillReturnError(want) + + // now we execute our method + got := runInTransaction(db, commands...) + assert.Equal(t, want, got) + }) + + t.Run("it executes all commands", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand("test"), + } + + mock.ExpectBegin() + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(commands[1].toSQL()).WillReturnResult(sqlmock.NewResult(2, 1)) + mock.ExpectCommit() + + // now we execute our method + if err := runInTransaction(db, commands...); err != nil { + t.Errorf("error was not expected while running query: %s", err) + } + }) +} + +func TestRun(t *testing.T) { + t.Run("it returns an error on invalid command", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand(""), + } + + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + + err := run(db, commands...) + + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it returns an error on DB command execution", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand("dead"), + } + + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(commands[1].toSQL()).WillReturnError(errTestDBExecFailed) + + err := run(db, commands...) + + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it executes all commands", func(t *testing.T) { + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + commands := []command{ + testCommand("test"), + testDummyCommand("test"), + } + + mock.ExpectExec(commands[0].toSQL()).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(commands[1].toSQL()).WillReturnResult(sqlmock.NewResult(2, 1)) + + err := run(db, commands...) + + assert.Nil(t, err) + }) +} diff --git a/migrator.go b/migrator.go new file mode 100644 index 0000000..dbb58d2 --- /dev/null +++ b/migrator.go @@ -0,0 +1,309 @@ +// Package migrator represents MySQL database migrator +package migrator + +import ( + "database/sql" + "errors" + "fmt" + "strings" + "time" +) + +const migrationTable = "migrations" + +var ( + // ErrTableNotExists returned when migration table not found + ErrTableNotExists = errors.New("Migration table does not exists") + + // ErrNoMigrationDefined returned when no migations defined in the migrations pool + ErrNoMigrationDefined = errors.New("No migrations defined") + + // ErrEmptyRollbackStack returned when nothing can be reverted + ErrEmptyRollbackStack = errors.New("Nothing to rollback, there are no migration executed") + + // ErrMissingMigrationName returned when migration name is missing + ErrMissingMigrationName = errors.New("Missing migration name") + + // ErrNoSQLCommandsToRun returned when migration is invalid and has not commands in the pool + ErrNoSQLCommandsToRun = errors.New("There is no command to be executed") +) + +type migrationEntry struct { + id uint64 + name string + batch uint64 + appliedAt time.Time +} + +// Migrator represents a struct with migrations, that should be executed +// +// Default migration table name is `migrations`, but it can be re-defined +// Pool is a list of migrations that should be migrated +type Migrator struct { + // Name of the table to track executed migrations + TableName string + // stack of migrations + Pool []Migration + executed []migrationEntry +} + +// Migrate run all migrations from pool and stores in migration table executed migration +func (m Migrator) Migrate(db *sql.DB) (migrated []string, err error) { + if len(m.Pool) == 0 { + return migrated, ErrNoMigrationDefined + } + + if err := m.checkMigrationPool(); err != nil { + return migrated, err + } + + if err := m.createMigrationTable(db); err != nil { + return migrated, fmt.Errorf("Migration table failed to be created: %v", err) + } + + if err := m.fetchExecuted(db); err != nil { + return migrated, err + } + + batch := m.batch() + 1 + table := m.table() + + for _, item := range m.Pool { + if m.isExecuted(item.Name) { + continue + } + + s := item.Up() + if len(s.pool) == 0 { + return migrated, ErrNoSQLCommandsToRun + } + if err := item.exec(db, s.pool...); err != nil { + return migrated, err + } + + entry := migrationEntry{name: item.Name, batch: batch} + sql := fmt.Sprintf("INSERT INTO `%s` (`name`, `batch`) VALUES (\"%s\", %d)", table, entry.name, entry.batch) + + if _, err := db.Exec(sql); err != nil { + return migrated, err + } + + migrated = append(migrated, item.Name) + } + + return migrated, nil +} + +// Rollback reverts last executed batch of migratios +func (m Migrator) Rollback(db *sql.DB) (reverted []string, err error) { + if len(m.Pool) == 0 { + return reverted, ErrNoMigrationDefined + } + + if err := m.checkMigrationPool(); err != nil { + return reverted, err + } + + if !m.hasTable(db) { + return reverted, ErrTableNotExists + } + + if err := m.fetchExecuted(db); err != nil { + return reverted, err + } + + if len(m.executed) == 0 { + return reverted, ErrEmptyRollbackStack + } + + table := m.table() + revertable := m.lastBatchExecuted() + + for i := len(revertable) - 1; i >= 0; i-- { + name := revertable[i].name + + for j := len(m.Pool) - 1; j >= 0; j-- { + item := m.Pool[j] + + if item.Name == name { + s := item.Down() + if len(s.pool) == 0 { + return reverted, ErrNoSQLCommandsToRun + } + if err := item.exec(db, s.pool...); err != nil { + return reverted, err + } + + if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE id = ?", table), revertable[i].id); err != nil { + return reverted, err + } + + reverted = append(reverted, name) + } + } + } + + return reverted, nil +} + +// Revert reverts all executed migration from the pool +func (m Migrator) Revert(db *sql.DB) (reverted []string, err error) { + if len(m.Pool) == 0 { + return reverted, ErrNoMigrationDefined + } + + if err := m.checkMigrationPool(); err != nil { + return reverted, err + } + + if !m.hasTable(db) { + return reverted, ErrTableNotExists + } + + if err := m.fetchExecuted(db); err != nil { + return reverted, err + } + + if len(m.executed) == 0 { + return reverted, ErrEmptyRollbackStack + } + + table := m.table() + + for i := len(m.executed) - 1; i >= 0; i-- { + name := m.executed[i].name + + for j := len(m.Pool) - 1; j >= 0; j-- { + item := m.Pool[j] + + if item.Name == name { + s := item.Down() + if len(s.pool) == 0 { + return reverted, ErrNoSQLCommandsToRun + } + if err := item.exec(db, s.pool...); err != nil { + return reverted, err + } + + if _, err := db.Exec(fmt.Sprintf("DELETE FROM %s WHERE id = ?", table), m.executed[i].id); err != nil { + return reverted, err + } + + reverted = append(reverted, name) + } + } + } + + return reverted, nil +} + +func (m Migrator) checkMigrationPool() error { + var names []string + + for _, item := range m.Pool { + if item.Name == "" { + return ErrMissingMigrationName + } + + for _, exist := range names { + if exist == item.Name { + return fmt.Errorf(`Migration "%s" is duplicated in the pool`, exist) + } + } + + names = append(names, item.Name) + } + + return nil +} + +func (m Migrator) createMigrationTable(db *sql.DB) error { + if m.hasTable(db) { + return nil + } + + sql := fmt.Sprintf( + "CREATE TABLE %s (%s) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + m.table(), + strings.Join([]string{ + "id int(10) unsigned NOT NULL AUTO_INCREMENT PRIMARY KEY", + "name varchar(255) COLLATE utf8mb4_unicode_ci NOT NULL", + "batch int(11) NOT NULL", + "applied_at timestamp NULL DEFAULT CURRENT_TIMESTAMP", + }, ", "), + ) + + _, err := db.Exec(sql) + + return err +} + +func (m Migrator) hasTable(db *sql.DB) bool { + _, hasTable := db.Query("SELECT * FROM " + m.table()) + + return hasTable == nil +} + +func (m Migrator) table() string { + table := m.TableName + if table == "" { + table = migrationTable + } + + return table +} + +func (m Migrator) batch() uint64 { + var batch uint64 + + for _, item := range m.executed { + if item.batch > batch { + batch = item.batch + } + } + + return batch +} + +func (m *Migrator) fetchExecuted(db *sql.DB) error { + rows, err := db.Query("SELECT id, name, batch, applied_at FROM " + m.table() + " ORDER BY applied_at ASC") + if err != nil { + return err + } + m.executed = []migrationEntry{} + + for rows.Next() { + var entry migrationEntry + + if err := rows.Scan(&entry.id, &entry.name, &entry.batch, &entry.appliedAt); err != nil { + return err + } + + m.executed = append(m.executed, entry) + } + + return nil +} + +func (m Migrator) isExecuted(name string) bool { + for _, item := range m.executed { + if item.name == name { + return true + } + } + + return false +} + +func (m Migrator) lastBatchExecuted() []migrationEntry { + batch := m.batch() + var result []migrationEntry + + for _, item := range m.executed { + if item.batch == batch { + result = append(result, item) + } + } + + return result +} diff --git a/migrator_test.go b/migrator_test.go new file mode 100644 index 0000000..0f17e14 --- /dev/null +++ b/migrator_test.go @@ -0,0 +1,841 @@ +package migrator + +import ( + "fmt" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/assert" +) + +func TestMigrate(t *testing.T) { + t.Run("it fails when migration pool is empty", func(t *testing.T) { + m := Migrator{} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoMigrationDefined, err) + }) + + t.Run("it fails when there is invalid item in the migration pool", func(t *testing.T) { + migration := Migration{} + m := Migrator{Pool: []Migration{migration}} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, ErrMissingMigrationName, err) + }) + + t.Run("it fails when migration table creation failed", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows().WillReturnError(errTestDBQueryFailed) + mock.ExpectExec("CREATE").WillReturnError(errTestDBExecFailed) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, fmt.Errorf("Migration table failed to be created: %v", errTestDBExecFailed), err) + }) + + t.Run("it fails while fetching executed list", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnError(errTestDBExecFailed) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it skips execution when it was already executed", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Nil(t, err) + }) + + t.Run("it fails executing empty list of migrations", func(t *testing.T) { + migration := Migration{Name: "test", Up: func() Schema { + var s Schema + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails executing migration commands", func(t *testing.T) { + migration := Migration{Name: "test", Up: func() Schema { + var s Schema + s.pool = append(s.pool, testDummyCommand("")) + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails while storing executed migration info", func(t *testing.T) { + migration := Migration{Name: "test", Up: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("INSERT").WillReturnError(errTestDBExecFailed) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it executes migrations and returns list of migrated items", func(t *testing.T) { + migration := Migration{Name: "test", Up: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 4, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec(`INSERT .* VALUES \("test", 5\)`).WillReturnResult(sqlmock.NewResult(1, 1)) + + migrated, err := m.Migrate(db) + + assert.Len(t, migrated, 1) + assert.Equal(t, migrated[0], "test") + assert.Nil(t, err) + }) +} + +func TestRollback(t *testing.T) { + t.Run("it fails when migration pool is empty", func(t *testing.T) { + m := Migrator{} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoMigrationDefined, err) + }) + + t.Run("it fails when there is invalid item in the migration pool", func(t *testing.T) { + migration := Migration{} + m := Migrator{Pool: []Migration{migration}} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrMissingMigrationName, err) + }) + + t.Run("it fails when migration table missing", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows().WillReturnError(errTestDBQueryFailed) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrTableNotExists, err) + }) + + t.Run("it fails while fetching executed list", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnError(errTestDBExecFailed) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it exits when executed list is empty", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(sqlmock.NewRows([]string{})) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrEmptyRollbackStack, err) + }) + + t.Run("it does nothing when executed migration not in the migration pool", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.pool = append(s.pool, testDummyCommand("")) + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Nil(t, err) + }) + + t.Run("it fails executing empty list of commands", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails executing migration commands", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.pool = append(s.pool, testDummyCommand("")) + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails while removing executed migration info", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DELETE").WillReturnError(errTestDBExecFailed) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it roll back migrations and returns list of reverted items", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }} + m := Migrator{Pool: []Migration{migration, {Name: "new"}}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}). + AddRow(1, "test", 4, time.Now()). + AddRow(2, "new", 3, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DELETE FROM migrations WHERE id = ?").WithArgs(1).WillReturnResult(sqlmock.NewResult(1, 1)) + + reverted, err := m.Rollback(db) + + assert.Len(t, reverted, 1) + assert.Equal(t, reverted[0], "test") + assert.Nil(t, err) + }) +} + +func TestRevert(t *testing.T) { + t.Run("it fails when migration pool is empty", func(t *testing.T) { + m := Migrator{} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoMigrationDefined, err) + }) + + t.Run("it fails when there is invalid item in the migration pool", func(t *testing.T) { + migration := Migration{} + m := Migrator{Pool: []Migration{migration}} + db, _, resetDB := testDBConnection(t) + defer resetDB() + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrMissingMigrationName, err) + }) + + t.Run("it fails when migration table missing", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows().WillReturnError(errTestDBQueryFailed) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrTableNotExists, err) + }) + + t.Run("it fails while fetching executed list", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnError(errTestDBExecFailed) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it exits when executed list is empty", func(t *testing.T) { + migration := Migration{Name: "test"} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(sqlmock.NewRows([]string{})) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrEmptyRollbackStack, err) + }) + + t.Run("it does nothing when executed migration not in the migration pool", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.pool = append(s.pool, testDummyCommand("")) + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "new", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Nil(t, err) + }) + + t.Run("it fails executing empty list of commands", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails executing migration commands", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.pool = append(s.pool, testDummyCommand("")) + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, ErrNoSQLCommandsToRun, err) + }) + + t.Run("it fails while removing executed migration info", func(t *testing.T) { + migration := Migration{Name: "test", Down: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }} + m := Migrator{Pool: []Migration{migration}} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}).AddRow(1, "test", 1, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DELETE").WillReturnError(errTestDBExecFailed) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 0) + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) + + t.Run("it roll back migrations and returns list of reverted items", func(t *testing.T) { + m := Migrator{Pool: []Migration{ + {Name: "test", Down: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }}, + {Name: "new", Down: func() Schema { + var s Schema + s.DropTable("test", false, "") + return s + }}, + }} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}). + AddRow(1, "test", 4, time.Now()). + AddRow(2, "new", 3, time.Now()) + + mock.ExpectQuery("SELECT").WillReturnRows() + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DELETE FROM migrations WHERE id = ?").WithArgs(2).WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DROP").WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("DELETE FROM migrations WHERE id = ?").WithArgs(1).WillReturnResult(sqlmock.NewResult(1, 1)) + + reverted, err := m.Revert(db) + + assert.Len(t, reverted, 2) + assert.Equal(t, reverted[0], "new") + assert.Equal(t, reverted[1], "test") + assert.Nil(t, err) + }) +} + +func TestCheckMigrationPool(t *testing.T) { + t.Run("it is successful on empty pool", func(t *testing.T) { + m := Migrator{} + err := m.checkMigrationPool() + + assert.Nil(t, err) + }) + + t.Run("It is successful for proper pool", func(t *testing.T) { + m := Migrator{Pool: []Migration{ + {Name: "test"}, + {Name: "random"}, + }} + err := m.checkMigrationPool() + + assert.Nil(t, err) + }) + + t.Run("it returns an error on missing migration name", func(t *testing.T) { + m := Migrator{Pool: []Migration{ + {Name: "test"}, + {Name: "random"}, + {Name: ""}, + }} + err := m.checkMigrationPool() + + assert.Error(t, err) + assert.Equal(t, ErrMissingMigrationName, err) + }) + + t.Run("it returns an error on duplicated migration name", func(t *testing.T) { + m := Migrator{Pool: []Migration{ + {Name: "test"}, + {Name: "random"}, + {Name: "again"}, + {Name: "migration"}, + {Name: "again"}, + }} + err := m.checkMigrationPool() + + assert.NotNil(t, err) + assert.Equal(t, `Migration "again" is duplicated in the pool`, err.Error()) + }) +} + +func TestCreateMigrationTable(t *testing.T) { + t.Run("it ignores creation if table exists", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery(`SELECT \* FROM migrations`).WillReturnRows().WillReturnError(nil) + + err := m.createMigrationTable(db) + + assert.Nil(t, err) + }) + + t.Run("it creates migration table", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery(`SELECT \* FROM migrations`).WillReturnError(errTestDBQueryFailed) + sql := `CREATE TABLE migrations \(id int\(10\) unsigned NOT NULL AUTO_INCREMENT PRIMARY KEY, name varchar\(255\) COLLATE utf8mb4_unicode_ci NOT NULL, batch int\(11\) NOT NULL, applied_at timestamp NULL DEFAULT CURRENT_TIMESTAMP\) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` + mock.ExpectExec(sql).WillReturnResult(sqlmock.NewResult(1, 1)) + + err := m.createMigrationTable(db) + + assert.Nil(t, err) + }) + + t.Run("it fails creating table", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery(`SELECT \* FROM migrations`).WillReturnError(errTestDBQueryFailed) + sql := `CREATE TABLE migrations \(` + + `id int\(10\) unsigned NOT NULL AUTO_INCREMENT PRIMARY KEY, ` + + `name varchar\(255\) COLLATE utf8mb4_unicode_ci NOT NULL, ` + + `batch int\(11\) NOT NULL, applied_at timestamp NULL DEFAULT CURRENT_TIMESTAMP\) ` + + `ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci` + mock.ExpectExec(sql).WillReturnError(errTestDBExecFailed) + + err := m.createMigrationTable(db) + + assert.Error(t, err) + assert.Equal(t, errTestDBExecFailed, err) + }) +} + +func TestHasTable(t *testing.T) { + t.Run("it returns true if table exists", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery(`SELECT \* FROM migrations`).WillReturnRows().WillReturnError(nil) + got := m.hasTable(db) + + assert.Equal(t, true, got) + }) + + t.Run("it returns false if table does not exists", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery(`SELECT \* FROM migrations`).WillReturnError(errTestDBQueryFailed) + got := m.hasTable(db) + + assert.Equal(t, false, got) + }) +} + +func TestMigrationTable(t *testing.T) { + t.Run("it returns default table name", func(t *testing.T) { + m := Migrator{} + got := m.table() + + assert.Equal(t, "migrations", got) + }) + + t.Run("it returns selected table name", func(t *testing.T) { + m := Migrator{TableName: "table"} + got := m.table() + + assert.Equal(t, "table", got) + }) +} + +func TestBatch(t *testing.T) { + t.Run("it returns zero on empty executed list", func(t *testing.T) { + m := Migrator{} + got := m.batch() + + assert.Equal(t, uint64(0), got) + }) + + t.Run("it returns zero if migration batch is zero", func(t *testing.T) { + m := Migrator{ + executed: []migrationEntry{ + {batch: uint64(0)}, + }, + } + got := m.batch() + + assert.Equal(t, uint64(0), got) + }) + + t.Run("it returns the biggest batch from migration list", func(t *testing.T) { + m := Migrator{ + executed: []migrationEntry{ + {batch: uint64(6)}, + {batch: uint64(3)}, + {batch: uint64(15)}, + {batch: uint64(12)}, + }, + } + got := m.batch() + + assert.Equal(t, uint64(15), got) + }) +} + +func TestPoolExecuted(t *testing.T) { + t.Run("it fails executing query", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnError(errTestDBQueryFailed) + + err := m.fetchExecuted(db) + + assert.Error(t, err) + assert.Equal(t, errTestDBQueryFailed, err) + assert.Nil(t, m.executed) + }) + + t.Run("it fails scanning row", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}). + AddRow(1, "first", 1, time.Now()). + AddRow(2, "second", 1, "test") + + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + got := m.fetchExecuted(db) + + assert.Error(t, got) + assert.NotNil(t, m.executed) + assert.Len(t, m.executed, 1) + }) + + t.Run("it returns a list of executed migrations", func(t *testing.T) { + m := Migrator{} + db, mock, resetDB := testDBConnection(t) + defer resetDB() + + rows := sqlmock.NewRows([]string{"id", "name", "batch", "applied_at"}). + AddRow(1, "first", 1, time.Now()). + AddRow(2, "second", 1, time.Now()) + + mock.ExpectQuery("SELECT id, name, batch, applied_at FROM migrations").WillReturnRows(rows) + + err := m.fetchExecuted(db) + + assert.Nil(t, err) + assert.NotNil(t, m.executed) + assert.Len(t, m.executed, 2) + }) +} + +func TestIsExecuted(t *testing.T) { + t.Run("it returns false on empty executed list", func(t *testing.T) { + m := Migrator{} + got := m.isExecuted("test") + + assert.Equal(t, false, got) + }) + + t.Run("it returns false if migration wasn't executed yet", func(t *testing.T) { + m := Migrator{ + executed: []migrationEntry{ + {name: "test"}, + {name: "random"}, + {name: "lorem"}, + {name: "ipsum"}, + }, + } + got := m.isExecuted("") + + assert.Equal(t, false, got) + }) + + t.Run("it returns true if migration was executed", func(t *testing.T) { + m := Migrator{ + executed: []migrationEntry{ + {name: "test"}, + {name: "random"}, + {name: "lorem"}, + {name: "ipsum"}, + }, + } + got := m.isExecuted("random") + + assert.Equal(t, true, got) + }) +} + +func TestLastExecutedForBatch(t *testing.T) { + t.Run("it returns an empty list if nothing found for biggest batch", func(t *testing.T) { + m := Migrator{} + got := m.lastBatchExecuted() + + assert.Len(t, got, 0) + }) + + t.Run("", func(t *testing.T) { + m := Migrator{ + executed: []migrationEntry{ + {name: "test", batch: 1}, + {name: "again", batch: 3}, + {name: "random", batch: 2}, + {name: "lorem", batch: 3}, + {name: "ipsum", batch: 3}, + }, + } + got := m.lastBatchExecuted() + + assert.Len(t, got, 3) + }) +} diff --git a/schema.go b/schema.go new file mode 100644 index 0000000..977a8ad --- /dev/null +++ b/schema.go @@ -0,0 +1,68 @@ +// Package migrator represents MySQL database migrator +package migrator + +// Schema allows to add commands on schema. +// It should be used within migration to add migration commands. +type Schema struct { + pool []command +} + +// CreateTable allows to create table in schema +// +// Example: +// var s migrator.Schema +// t := migrator.Table{Name: "test"} +// +// s.CreateTable(t) +func (s *Schema) CreateTable(t Table) { + s.pool = append(s.pool, createTableCommand{t}) +} + +// DropTable removes table from schema +// Warning ⚠️ BC incompatible +// +// Example: +// var s migrator.Schema +// s.DropTable("test", false, "") +// +// Soft delete (drop if exists) +// s.DropTable("test", true, "") +func (s *Schema) DropTable(name string, soft bool, option string) { + s.pool = append(s.pool, dropTableCommand{name, soft, option}) +} + +// RenameTable executes command to rename table +// Warning ⚠️ BC incompatible +// +// Example: +// var s migrator.Schema +// s.RenameTable("old", "new") +func (s *Schema) RenameTable(old string, new string) { + s.pool = append(s.pool, renameTableCommand{old: old, new: new}) +} + +// AlterTable makes changes on table level +// +// Example: +// var s migrator.Schema +// var c TableCommands +// s.AlterTable("test", c) +func (s *Schema) AlterTable(name string, c TableCommands) { + s.pool = append(s.pool, alterTableCommand{name, c}) +} + +// CustomCommand allows to add custom command to the Schema +// +// Example: +// type customCommand string +// +// func (c customCommand) toSQL() string { +// return string(c) +// } +// +// c := customCommand("DROP PROCEDURE abc") +// var s migrator.Schema +// s.CustomCommand(c) +func (s *Schema) CustomCommand(c command) { + s.pool = append(s.pool, c) +} diff --git a/schema_command.go b/schema_command.go new file mode 100644 index 0000000..2324b5c --- /dev/null +++ b/schema_command.go @@ -0,0 +1,117 @@ +// Package migrator represents MySQL database migrator +package migrator + +import ( + "fmt" + "strings" +) + +type command interface { + toSQL() string +} + +type createTableCommand struct { + t Table +} + +func (c createTableCommand) toSQL() string { + if c.t.Name == "" { + return "" + } + + context := c.t.columns.render() + if context == "" { + context = "`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT" + } + + if res := c.t.indexes.render(); res != "" { + context += ", " + res + } + + if res := c.t.foreigns.render(); res != "" { + context += ", " + res + } + + engine := c.t.Engine + if engine == "" { + engine = "InnoDB" + } + + charset := c.t.Charset + collation := c.t.Collation + if charset == "" && collation == "" { + charset = "utf8mb4" + collation = "utf8mb4_unicode_ci" + } + if charset == "" && collation != "" { + parts := strings.Split(collation, "_") + charset = parts[0] + } + if charset != "" && collation == "" { + collation = charset + "_unicode_ci" + } + + return fmt.Sprintf( + "CREATE TABLE `%s` (%s) ENGINE=%s DEFAULT CHARSET=%s COLLATE=%s", + c.t.Name, + context, + engine, + charset, + collation, + ) +} + +type dropTableCommand struct { + table string + soft bool + option string +} + +func (c dropTableCommand) toSQL() string { + sql := "DROP TABLE" + + if c.soft { + sql += " IF EXISTS" + } + + sql += fmt.Sprintf(" `%s`", c.table) + + var validOptions = list{"RESTRICT", "CASCADE"} + if validOptions.has(strings.ToUpper(c.option)) { + sql += " " + strings.ToUpper(c.option) + } + + return sql +} + +type renameTableCommand struct { + old string + new string +} + +func (c renameTableCommand) toSQL() string { + return fmt.Sprintf("RENAME TABLE `%s` TO `%s`", c.old, c.new) +} + +type alterTableCommand struct { + name string + pool TableCommands +} + +func (c alterTableCommand) toSQL() string { + if c.name == "" || len(c.pool) == 0 { + return "" + } + + return "ALTER TABLE `" + c.name + "` " + c.poolToSQL() +} + +func (c alterTableCommand) poolToSQL() string { + var sql []string + + for _, tc := range c.pool { + sql = append(sql, tc.toSQL()) + } + + return strings.Join(sql, ", ") +} diff --git a/schema_command_test.go b/schema_command_test.go new file mode 100644 index 0000000..1e3b063 --- /dev/null +++ b/schema_command_test.go @@ -0,0 +1,232 @@ +package migrator + +import ( + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testCommand string + +func (c testCommand) toSQL() string { + return "Do action on " + string(c) +} + +func TestCreateTableCommand(t *testing.T) { + t.Run("it returns empty string when table name missing", func(t *testing.T) { + tb := Table{} + c := createTableCommand{tb} + + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it renders default table", func(t *testing.T) { + tb := Table{Name: "test"} + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + c.toSQL(), + ) + }) + + t.Run("it renders columns", func(t *testing.T) { + tb := Table{ + Name: "test", + columns: []column{ + {"test", testColumnType("random thing")}, + {"random", testColumnType("another thing")}, + }, + } + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`test` random thing, `random` another thing) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + c.toSQL(), + ) + }) + + t.Run("it renders indexes", func(t *testing.T) { + tb := Table{ + Name: "test", + indexes: []key{ + {name: "idx_rand", columns: []string{"id"}}, + {columns: []string{"id", "name"}}, + }, + } + c := createTableCommand{tb} + + assert.Equal( + t, + strings.Join([]string{ + "CREATE TABLE `test` (", + "`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, ", + "KEY `idx_rand` (`id`), KEY (`id`, `name`)", + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + }, ""), + c.toSQL(), + ) + }) + + t.Run("it renders foreigns", func(t *testing.T) { + tb := Table{ + Name: "test", + foreigns: []foreign{ + {key: "idx_foreign", column: "test_id", reference: "id", on: "tests"}, + {key: "foreign_idx", column: "random_id", reference: "id", on: "randoms"}, + }, + } + c := createTableCommand{tb} + + assert.Equal( + t, + strings.Join([]string{ + "CREATE TABLE `test` (", + "`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT, ", + "CONSTRAINT `idx_foreign` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`), ", + "CONSTRAINT `foreign_idx` FOREIGN KEY (`random_id`) REFERENCES `randoms` (`id`)", + ") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + }, ""), + c.toSQL(), + ) + }) + + t.Run("it renders engine", func(t *testing.T) { + tb := Table{Name: "test", Engine: "MyISAM"} + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT) ENGINE=MyISAM DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci", + c.toSQL(), + ) + }) + + t.Run("it renders charset and collation", func(t *testing.T) { + tb := Table{Name: "test", Charset: "rand", Collation: "random_io"} + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT) ENGINE=InnoDB DEFAULT CHARSET=rand COLLATE=random_io", + c.toSQL(), + ) + }) + + t.Run("it renders charset and manually add collation", func(t *testing.T) { + tb := Table{Name: "test", Charset: "utf8"} + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_unicode_ci", + c.toSQL(), + ) + }) + + t.Run("it renders collation and manually add charset", func(t *testing.T) { + tb := Table{Name: "test", Collation: "utf8_general_ci"} + c := createTableCommand{tb} + + assert.Equal( + t, + "CREATE TABLE `test` (`id` bigint(20) unsigned NOT NULL AUTO_INCREMENT) ENGINE=InnoDB DEFAULT CHARSET=utf8 COLLATE=utf8_general_ci", + c.toSQL(), + ) + }) + + t.Run("it renders all together", func(t *testing.T) { + tb := Table{ + Name: "test", + columns: []column{ + {"test", testColumnType("random thing")}, + {"random", testColumnType("another thing")}, + }, + indexes: []key{ + {name: "idx_rand", columns: []string{"id"}}, + {columns: []string{"id", "name"}}, + }, + foreigns: []foreign{ + {key: "idx_foreign", column: "test_id", reference: "id", on: "tests"}, + {key: "foreign_idx", column: "random_id", reference: "id", on: "randoms"}, + }, + Engine: "MyISAM", + Charset: "rand", + Collation: "random_io", + } + c := createTableCommand{tb} + + assert.Equal( + t, + strings.Join([]string{ + "CREATE TABLE `test` (", + "`test` random thing, `random` another thing, ", + "KEY `idx_rand` (`id`), KEY (`id`, `name`), ", + "CONSTRAINT `idx_foreign` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`), ", + "CONSTRAINT `foreign_idx` FOREIGN KEY (`random_id`) REFERENCES `randoms` (`id`)", + ") ENGINE=MyISAM DEFAULT CHARSET=rand COLLATE=random_io", + }, ""), + c.toSQL(), + ) + }) +} + +func TestDropTableCommand(t *testing.T) { + t.Run("it drops table", func(t *testing.T) { + c := dropTableCommand{"test", false, ""} + assert.Equal(t, "DROP TABLE `test`", c.toSQL()) + }) + + t.Run("it drops table if exists", func(t *testing.T) { + c := dropTableCommand{"test", true, ""} + assert.Equal(t, "DROP TABLE IF EXISTS `test`", c.toSQL()) + }) + + t.Run("it drops table with cascade flag", func(t *testing.T) { + c := dropTableCommand{"test", false, "cascade"} + assert.Equal(t, "DROP TABLE `test` CASCADE", c.toSQL()) + }) + + t.Run("it drops table if exists with restrict flag", func(t *testing.T) { + c := dropTableCommand{"test", true, "restrict"} + assert.Equal(t, "DROP TABLE IF EXISTS `test` RESTRICT", c.toSQL()) + }) +} + +func TestRenameTableCommand(t *testing.T) { + c := renameTableCommand{"from", "to"} + + assert.Equal(t, "RENAME TABLE `from` TO `to`", c.toSQL()) +} + +func TestAlterTableCommand(t *testing.T) { + t.Run("it returns an empty command if table name is missing", func(t *testing.T) { + c := alterTableCommand{pool: TableCommands{testCommand("test")}} + + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty command if pool is empty", func(t *testing.T) { + c := alterTableCommand{name: "test"} + + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it renders command with one alter sub-command", func(t *testing.T) { + c := alterTableCommand{name: "test", pool: TableCommands{testCommand("test")}} + + assert.Equal(t, "ALTER TABLE `test` Do action on test", c.toSQL()) + }) + + t.Run("it renders command with multiple alter sub-command", func(t *testing.T) { + c := alterTableCommand{ + name: "test", + pool: TableCommands{testCommand("test"), testCommand("bang")}, + } + + assert.Equal(t, "ALTER TABLE `test` Do action on test, Do action on bang", c.toSQL()) + }) +} diff --git a/schema_test.go b/schema_test.go new file mode 100644 index 0000000..ac77f0e --- /dev/null +++ b/schema_test.go @@ -0,0 +1,69 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSchemaCreateTable(t *testing.T) { + assert := assert.New(t) + + s := Schema{} + assert.Len(s.pool, 0) + + tb := Table{Name: "test"} + s.CreateTable(tb) + + assert.Len(s.pool, 1) + assert.Equal(createTableCommand{tb}, s.pool[0]) +} + +func TestSchemaDropTable(t *testing.T) { + assert := assert.New(t) + + s := Schema{} + assert.Len(s.pool, 0) + + s.DropTable("test", false, "") + + assert.Len(s.pool, 1) + assert.Equal(dropTableCommand{"test", false, ""}, s.pool[0]) +} + +func TestSchemaRenameTable(t *testing.T) { + assert := assert.New(t) + + s := Schema{} + assert.Len(s.pool, 0) + + s.RenameTable("from", "to") + + assert.Len(s.pool, 1) + assert.Equal(renameTableCommand{"from", "to"}, s.pool[0]) +} + +func TestSchemaAlterTable(t *testing.T) { + assert := assert.New(t) + + s := Schema{} + assert.Len(s.pool, 0) + + s.AlterTable("table", TableCommands{}) + + assert.Len(s.pool, 1) + assert.Equal(alterTableCommand{"table", TableCommands{}}, s.pool[0]) +} + +func TestSchemaCustomCommand(t *testing.T) { + assert := assert.New(t) + c := testDummyCommand("DROP PROCEDURE abc") + + s := Schema{} + assert.Len(s.pool, 0) + + s.CustomCommand(c) + + assert.Len(s.pool, 1) + assert.Equal(c, s.pool[0]) +} diff --git a/table.go b/table.go new file mode 100644 index 0000000..4f45c2e --- /dev/null +++ b/table.go @@ -0,0 +1,139 @@ +// Package migrator represents MySQL database migrator +package migrator + +import "strings" + +// Table is an entity to create table +// +// Name table name +// Engine default: InnoDB +// Charset default: utf8mb4 or first part of collation (if set) +// Collation default: utf8mb4_unicode_ci or charset with `_unicode_ci` suffix +// Comment optional comment on table +type Table struct { + Name string + columns columns + indexes keys + foreigns foreigns + Engine string + Charset string + Collation string + Comment string +} + +// Column adds column to the table +func (t *Table) Column(name string, c columnType) { + t.columns = append(t.columns, column{field: name, definition: c}) +} + +// ID adds bigint `id` column that is primary key +func (t *Table) ID(name string) { + t.Column(name, Integer{ + Prefix: "big", + Unsigned: true, + Autoincrement: true, + }) + t.Primary(name) +} + +// UniqueID adds unique id column (represented as UUID) that is primary key +func (t *Table) UniqueID(name string) { + t.UUID(name, "(UUID())", false) + t.Primary(name) +} + +// Boolean represented in DB as tinyint +func (t *Table) Boolean(name string, def string) { + // tinyint(1) + t.Column(name, Integer{ + Prefix: "tiny", + Unsigned: true, + Precision: 1, + Default: def, + }) +} + +// UUID adds char(36) column +func (t *Table) UUID(name string, def string, nullable bool) { + // char(36) + t.Column(name, String{ + Fixed: true, + Precision: 36, + Default: def, + Nullable: nullable, + }) +} + +// Timestamps adds default timestamps: `created_at` and `updated_at` +func (t *Table) Timestamps() { + // created_at not null default CURRENT_TIMESTAMP + t.Column("created_at", Timable{ + Type: "timestamp", + Default: "CURRENT_TIMESTAMP", + }) + // updated_at not null default CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP + t.Column("updated_at", Timable{ + Type: "timestamp", + Default: "CURRENT_TIMESTAMP", + OnUpdate: "CURRENT_TIMESTAMP", + }) +} + +// Primary adds primary key +func (t *Table) Primary(columns ...string) { + if len(columns) == 0 { + return + } + + t.indexes = append(t.indexes, key{ + typ: "primary", + columns: columns, + }) +} + +// Unique adds unique key on selected columns +func (t *Table) Unique(columns ...string) { + if len(columns) == 0 { + return + } + + t.indexes = append(t.indexes, key{ + name: t.buildUniqueKeyName(columns...), + typ: "unique", + columns: columns, + }) +} + +// Index adds index (key) on selected columns +func (t *Table) Index(name string, columns ...string) { + if len(columns) == 0 { + return + } + + t.indexes = append(t.indexes, key{name: name, columns: columns}) +} + +// Foreign adds foreign key constraints +func (t *Table) Foreign(column string, reference string, on string, onUpdate string, onDelete string) { + name := t.buildForeignKeyName(column) + t.indexes = append(t.indexes, key{ + name: name, + columns: []string{column}, + }) + t.foreigns = append(t.foreigns, foreign{ + key: name, + column: column, + reference: reference, + on: on, + onUpdate: onUpdate, + onDelete: onDelete, + }) +} + +func (t *Table) buildUniqueKeyName(columns ...string) string { + return t.Name + "_" + strings.Join(columns, "_") + "_unique" +} + +func (t *Table) buildForeignKeyName(column string) string { + return t.Name + "_" + column + "_foreign" +} diff --git a/table_command.go b/table_command.go new file mode 100644 index 0000000..0edb60f --- /dev/null +++ b/table_command.go @@ -0,0 +1,206 @@ +// Package migrator represents MySQL database migrator +package migrator + +import ( + "fmt" + "strings" +) + +// TableCommands is pool of commands to be executed on table +// https://dev.mysql.com/doc/refman/8.0/en/alter-table.html +type TableCommands []command + +func (tc TableCommands) toSQL() string { + rows := []string{} + + for _, c := range tc { + rows = append(rows, c.toSQL()) + } + + return strings.Join(rows, ", ") +} + +// AddColumnCommand is a command to add column to the table +type AddColumnCommand struct { + Name string + Column columnType + After string + First bool +} + +func (c AddColumnCommand) toSQL() string { + if c.Column == nil { + return "" + } + + definition := c.Column.buildRow() + if c.Name == "" || definition == "" { + return "" + } + + sql := "ADD COLUMN `" + c.Name + "` " + definition + + if c.After != "" { + sql += " AFTER " + c.After + } else if c.First { + sql += " FIRST" + } + + return sql +} + +// RenameColumnCommand is a command to rename column in the table +// Warning ⚠️ BC incompatible +// Info ℹ️ extensions for Oracle compatibility +type RenameColumnCommand struct { + Old string + New string +} + +func (c RenameColumnCommand) toSQL() string { + if c.Old == "" || c.New == "" { + return "" + } + + return fmt.Sprintf("RENAME COLUMN `%s` TO `%s`", c.Old, c.New) +} + +// ModifyColumnCommand is a command to modify column type +// Warning ⚠️ BC incompatible +// Info ℹ️ extensions for Oracle compatibility +type ModifyColumnCommand struct { + Name string + Column columnType +} + +func (c ModifyColumnCommand) toSQL() string { + if c.Column == nil { + return "" + } + + definition := c.Column.buildRow() + if c.Name == "" || definition == "" { + return "" + } + + return fmt.Sprintf("MODIFY `%s` %s", c.Name, definition) +} + +// ChangeColumnCommand is a default command to change column +// Warning ⚠️ BC incompatible +type ChangeColumnCommand struct { + From string + To string + Column columnType +} + +func (c ChangeColumnCommand) toSQL() string { + if c.Column == nil { + return "" + } + + definition := c.Column.buildRow() + if c.From == "" || c.To == "" || definition == "" { + return "" + } + + return fmt.Sprintf("CHANGE `%s` `%s` %s", c.From, c.To, c.Column.buildRow()) +} + +// DropColumnCommand is a command to drop column from the table +// Warning ⚠️ BC incompatible +type DropColumnCommand string + +// campatible with Oracle +func (c DropColumnCommand) toSQL() string { + if c == "" { + return "" + } + + return fmt.Sprintf("DROP COLUMN `%s`", c) +} + +// AddIndexCommand adds a key to the table +type AddIndexCommand struct { + Name string + Columns []string +} + +func (c AddIndexCommand) toSQL() string { + if c.Name == "" || len(c.Columns) == 0 { + return "" + } + + return fmt.Sprintf("ADD KEY `%s` (`%s`)", c.Name, strings.Join(c.Columns, "`, `")) +} + +// DropIndexCommand removes the key from the table +type DropIndexCommand string + +func (c DropIndexCommand) toSQL() string { + if c == "" { + return "" + } + + return fmt.Sprintf("DROP KEY `%s`", c) +} + +// AddForeignCommand adds the foreign key contraint to the table +type AddForeignCommand struct { + Foreign foreign +} + +func (c AddForeignCommand) toSQL() string { + if c.Foreign.render() == "" { + return "" + } + + return "ADD " + c.Foreign.render() +} + +// DropForeignCommand is a command to remove foreign key contraint +type DropForeignCommand string + +func (c DropForeignCommand) toSQL() string { + if c == "" { + return "" + } + + return fmt.Sprintf("DROP FOREIGN KEY `%s`", c) +} + +// AddUniqueIndexCommand is a command to add unique key to the table on some columns +type AddUniqueIndexCommand struct { + Key string + Columns []string +} + +func (c AddUniqueIndexCommand) toSQL() string { + if c.Key == "" || len(c.Columns) == 0 { + return "" + } + + return fmt.Sprintf("ADD UNIQUE KEY `%s` (`%s`)", c.Key, strings.Join(c.Columns, "`, `")) +} + +// AddPrimaryIndexCommand is a command to add a primary key +type AddPrimaryIndexCommand string + +func (c AddPrimaryIndexCommand) toSQL() string { + if c == "" { + return "" + } + + return fmt.Sprintf("ADD PRIMARY KEY (`%s`)", c) +} + +// DropPrimaryIndexCommand is a command to remove primary key from the table +type DropPrimaryIndexCommand struct{} + +func (c DropPrimaryIndexCommand) toSQL() string { + return "DROP PRIMARY KEY" +} + +// ADD {FULLTEXT | SPATIAL} [INDEX | KEY] [index_name] (key_part,...) [index_option] ... +// DROP {CHECK | CONSTRAINT} symbol +// RENAME {INDEX | KEY} old_index_name TO new_index_name diff --git a/table_command_test.go b/table_command_test.go new file mode 100644 index 0000000..ae63677 --- /dev/null +++ b/table_command_test.go @@ -0,0 +1,221 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTableCommands(t *testing.T) { + t.Run("it returns empty on empty commands list", func(t *testing.T) { + c := TableCommands{} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it renders row from one command", func(t *testing.T) { + c := TableCommands{testCommand("test")} + assert.Equal(t, "Do action on test", c.toSQL()) + }) + + t.Run("it renders row from multiple commands", func(t *testing.T) { + c := TableCommands{testCommand("test"), testCommand("bang")} + assert.Equal(t, "Do action on test, Do action on bang", c.toSQL()) + }) +} + +func TestAddColumnCommand(t *testing.T) { + t.Run("it returns an empty string if column definition missing", func(t *testing.T) { + c := AddColumnCommand{Name: "tests"} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column name missing", func(t *testing.T) { + c := AddColumnCommand{Column: testColumnType("test")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column definition empty", func(t *testing.T) { + c := AddColumnCommand{Name: "tests", Column: testColumnType("")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns base row", func(t *testing.T) { + c := AddColumnCommand{Name: "test_id", Column: testColumnType("definition")} + assert.Equal(t, "ADD COLUMN `test_id` definition", c.toSQL()) + }) + + t.Run("it returns row with after column", func(t *testing.T) { + c := AddColumnCommand{Name: "test_id", Column: testColumnType("definition"), After: "id"} + assert.Equal(t, "ADD COLUMN `test_id` definition AFTER id", c.toSQL()) + }) + + t.Run("it returns row with first flag", func(t *testing.T) { + c := AddColumnCommand{Name: "test_id", Column: testColumnType("definition"), First: true} + assert.Equal(t, "ADD COLUMN `test_id` definition FIRST", c.toSQL()) + }) +} + +func TestRenameColumnCommand(t *testing.T) { + t.Run("it returns an empty string if old name missing", func(t *testing.T) { + c := RenameColumnCommand{New: "test"} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if new name missing", func(t *testing.T) { + c := RenameColumnCommand{Old: "test"} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := RenameColumnCommand{Old: "from", New: "to"} + assert.Equal(t, "RENAME COLUMN `from` TO `to`", c.toSQL()) + }) +} + +func TestModifyColumnCommand(t *testing.T) { + t.Run("it returns an empty string if column definition missing", func(t *testing.T) { + c := ModifyColumnCommand{Name: "tests"} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column name missing", func(t *testing.T) { + c := ModifyColumnCommand{Column: testColumnType("test")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column definition empty", func(t *testing.T) { + c := ModifyColumnCommand{Name: "tests", Column: testColumnType("")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := ModifyColumnCommand{Name: "test_id", Column: testColumnType("definition")} + assert.Equal(t, "MODIFY `test_id` definition", c.toSQL()) + }) +} + +func TestChangeColumnCommand(t *testing.T) { + t.Run("it returns an empty string if column definition missing", func(t *testing.T) { + c := ChangeColumnCommand{From: "tests", To: "something"} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column from name missing", func(t *testing.T) { + c := ChangeColumnCommand{To: "something", Column: testColumnType("test")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column to name missing", func(t *testing.T) { + c := ChangeColumnCommand{From: "tests", Column: testColumnType("test")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if column definition empty", func(t *testing.T) { + c := ChangeColumnCommand{From: "tests", To: "something", Column: testColumnType("")} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := ChangeColumnCommand{From: "tests", To: "something", Column: testColumnType("definition")} + assert.Equal(t, "CHANGE `tests` `something` definition", c.toSQL()) + }) +} + +func TestDropColumnCommand(t *testing.T) { + t.Run("it returns an empty string if column name missing", func(t *testing.T) { + c := DropColumnCommand("") + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := DropColumnCommand("test_id") + assert.Equal(t, "DROP COLUMN `test_id`", c.toSQL()) + }) +} + +func TestAddIndexCommand(t *testing.T) { + t.Run("it returns an empty string if index name missing", func(t *testing.T) { + c := AddIndexCommand{Columns: []string{"test"}} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if columns list empty", func(t *testing.T) { + c := AddIndexCommand{Name: "test", Columns: []string{}} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := AddIndexCommand{Name: "test_idx", Columns: []string{"test"}} + assert.Equal(t, "ADD KEY `test_idx` (`test`)", c.toSQL()) + }) +} + +func TestDropIndexCommand(t *testing.T) { + t.Run("it returns an empty string if index name missing", func(t *testing.T) { + c := DropIndexCommand("") + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := DropIndexCommand("test_idx") + assert.Equal(t, "DROP KEY `test_idx`", c.toSQL()) + }) +} + +func TestAddForeignCommand(t *testing.T) { + t.Run("it returns an empty string on missing foreign key", func(t *testing.T) { + c := AddForeignCommand{} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it builds a proper row", func(t *testing.T) { + c := AddForeignCommand{foreign{key: "idx_foreign", column: "test_id", reference: "id", on: "tests"}} + assert.Equal(t, "ADD CONSTRAINT `idx_foreign` FOREIGN KEY (`test_id`) REFERENCES `tests` (`id`)", c.toSQL()) + }) +} + +func TestDropForeignCommand(t *testing.T) { + t.Run("it returns an empty string if index name missing", func(t *testing.T) { + c := DropForeignCommand("") + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := DropForeignCommand("test_idx") + assert.Equal(t, "DROP FOREIGN KEY `test_idx`", c.toSQL()) + }) +} + +func TestAddUniqueIndexCommand(t *testing.T) { + t.Run("it returns an empty string if index name missing", func(t *testing.T) { + c := AddUniqueIndexCommand{Columns: []string{"test"}} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns an empty string if columns list empty", func(t *testing.T) { + c := AddUniqueIndexCommand{Key: "test", Columns: []string{}} + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := AddUniqueIndexCommand{Key: "test_idx", Columns: []string{"test"}} + assert.Equal(t, "ADD UNIQUE KEY `test_idx` (`test`)", c.toSQL()) + }) +} + +func TestAddPrimaryIndexCommand(t *testing.T) { + t.Run("it returns an empty string if index name missing", func(t *testing.T) { + c := AddPrimaryIndexCommand("") + assert.Equal(t, "", c.toSQL()) + }) + + t.Run("it returns a proper row", func(t *testing.T) { + c := AddPrimaryIndexCommand("test_idx") + assert.Equal(t, "ADD PRIMARY KEY (`test_idx`)", c.toSQL()) + }) +} + +func TestDropPrimaryIndexCommand(t *testing.T) { + c := DropPrimaryIndexCommand{} + assert.Equal(t, "DROP PRIMARY KEY", c.toSQL()) +} diff --git a/table_test.go b/table_test.go new file mode 100644 index 0000000..8fd9fc9 --- /dev/null +++ b/table_test.go @@ -0,0 +1,207 @@ +package migrator + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTableColumns(t *testing.T) { + c := testColumnType("test") + + assert := assert.New(t) + + table := Table{} + assert.Len(table.columns, 0) + + table.Column("test", c) + + assert.Len(table.columns, 1) + assert.Equal(columns{column{"test", c}}, table.columns) +} + +func TestIDColumn(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.columns) + assert.Len(table.indexes, 0) + + table.ID("id") + + assert.Len(table.columns, 1) + assert.Equal("id", table.columns[0].field) + assert.Equal(Integer{Prefix: "big", Unsigned: true, Autoincrement: true}, table.columns[0].definition) + assert.Len(table.indexes, 1) + assert.Equal(key{typ: "primary", columns: []string{"id"}}, table.indexes[0]) +} + +func TestUniqueIDColumn(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.columns) + assert.Len(table.indexes, 0) + + table.UniqueID("id") + + assert.Len(table.columns, 1) + assert.Equal("id", table.columns[0].field) + assert.Equal(String{Default: "(UUID())", Fixed: true, Precision: 36}, table.columns[0].definition) + assert.Len(table.indexes, 1) + assert.Equal(key{typ: "primary", columns: []string{"id"}}, table.indexes[0]) +} + +func TestBooleanColumn(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.columns) + + table.Boolean("flag", "1") + + assert.Len(table.columns, 1) + assert.Equal("flag", table.columns[0].field) + assert.Equal(Integer{Prefix: "tiny", Default: "1", Unsigned: true, Precision: 1}, table.columns[0].definition) +} + +func TestUUIDColumn(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.columns) + + table.UUID("uuid", "1111", true) + + assert.Len(table.columns, 1) + assert.Equal("uuid", table.columns[0].field) + assert.Equal(String{Default: "1111", Fixed: true, Precision: 36, Nullable: true}, table.columns[0].definition) +} + +func TestTimestampsColumn(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.columns) + + table.Timestamps() + + assert.Len(table.columns, 2) + assert.Equal("created_at", table.columns[0].field) + assert.Equal(Timable{Type: "timestamp", Default: "CURRENT_TIMESTAMP"}, table.columns[0].definition) + assert.Equal("updated_at", table.columns[1].field) + assert.Equal(Timable{Type: "timestamp", Default: "CURRENT_TIMESTAMP", OnUpdate: "CURRENT_TIMESTAMP"}, table.columns[1].definition) +} + +func TestTablePrimaryIndex(t *testing.T) { + t.Run("it skips adding key on empty columns list", func(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.indexes) + + table.Primary() + + assert.Nil(table.indexes) + }) + + t.Run("it adds primary key", func(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.indexes) + + table.Primary("id", "name") + + assert.Len(table.indexes, 1) + assert.Equal(key{typ: "primary", columns: []string{"id", "name"}}, table.indexes[0]) + }) +} + +func TestTableUniqueIndex(t *testing.T) { + t.Run("it skips adding key on empty columns list", func(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.indexes) + + table.Unique() + + assert.Nil(table.indexes) + }) + + t.Run("it adds unique key", func(t *testing.T) { + assert := assert.New(t) + table := Table{Name: "table"} + + assert.Nil(table.indexes) + + table.Unique("id", "name") + + assert.Len(table.indexes, 1) + assert.Equal(key{name: "table_id_name_unique", typ: "unique", columns: []string{"id", "name"}}, table.indexes[0]) + }) +} + +func TestTableIndex(t *testing.T) { + t.Run("it skips adding key on empty columns list", func(t *testing.T) { + assert := assert.New(t) + table := Table{} + + assert.Nil(table.indexes) + + table.Index("test") + + assert.Nil(table.indexes) + }) + + t.Run("it adds unique key", func(t *testing.T) { + assert := assert.New(t) + table := Table{Name: "table"} + + assert.Nil(table.indexes) + + table.Index("test_idx", "id", "name") + + assert.Len(table.indexes, 1) + assert.Equal(key{name: "test_idx", columns: []string{"id", "name"}}, table.indexes[0]) + }) +} + +func TestTableForeignIndex(t *testing.T) { + assert := assert.New(t) + table := Table{Name: "table"} + + assert.Nil(table.indexes) + assert.Nil(table.foreigns) + + table.Foreign("test_id", "id", "tests", "set null", "cascade") + + assert.Len(table.indexes, 1) + assert.Equal(key{name: "table_test_id_foreign", columns: []string{"test_id"}}, table.indexes[0]) + assert.Len(table.foreigns, 1) + assert.Equal( + foreign{key: "table_test_id_foreign", column: "test_id", reference: "id", on: "tests", onUpdate: "set null", onDelete: "cascade"}, + table.foreigns[0], + ) +} + +func TestBuildUniqueIndexName(t *testing.T) { + t.Run("It builds name from one column", func(t *testing.T) { + table := Table{Name: "table"} + + assert.Equal(t, "table_test_unique", table.buildUniqueKeyName("test")) + }) + + t.Run("it builds name from multiple columns", func(t *testing.T) { + table := Table{Name: "table"} + + assert.Equal(t, "table_test_again_unique", table.buildUniqueKeyName("test", "again")) + }) +} + +func TestBuildForeignIndexName(t *testing.T) { + table := Table{Name: "table"} + + assert.Equal(t, "table_test_foreign", table.buildForeignKeyName("test")) +}