diff --git a/cmd/create.go b/cmd/create.go index 9687021..6cf280a 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -1,6 +1,7 @@ package cmd import ( + "bytes" "fmt" "path" "path/filepath" @@ -35,12 +36,13 @@ var createCmd = &cobra.Command{ inUp := "" if cmdCtx.Create.Dump { - dump, err := am.DumpSchema() + buffer := &bytes.Buffer{} + err := am.DumpSchema(buffer, true) if err != nil { return err } - inUp += fmt.Sprintf("s.Exec(`%s`)\n", dump) + inUp += fmt.Sprintf("s.Exec(`%s`)\n", buffer.String()) cmdCtx.Create.Type = "classic" } @@ -107,9 +109,6 @@ func init() { createCmd.Flags().BoolVarP(&cmdCtx.Create.Dump, "dump", "d", false, "dump with pg_dump the current schema and add it to the current migration") - createCmd.Flags().StringVarP(&cmdCtx.Create.DumpSchema, "dump-schema", "s", "public", - "the schema to dump if --dump is set") - createCmd.Flags().StringVar(&cmdCtx.Create.SQLSeparator, "sql-separator", "-- migrate:down", "the separator to split the up and down part of the migration") diff --git a/cmd/migration.go b/cmd/migration.go index 84c6610..f8a4383 100644 --- a/cmd/migration.go +++ b/cmd/migration.go @@ -15,7 +15,11 @@ var migrateCmd = &cobra.Command{ return err } - return am.ExecuteMain(amigo.MainArgMigrate) + if err := am.ExecuteMain(amigo.MainArgMigrate); err != nil { + return err + } + + return nil }), } @@ -43,10 +47,15 @@ func init() { cmd.Flags().BoolVar(&m.ContinueOnError, "continue-on-error", false, "Will not rollback the migration if an error occurs") cmd.Flags().DurationVar(&m.Timeout, "timeout", amigoctx.DefaultTimeout, "The timeout for the migration") + cmd.Flags().BoolVarP(&m.DumpSchemaAfter, "dump-schema-after", "d", false, + "Dump schema after migrate/rollback (not compatible with --use-schema-dump)") } registerBase(migrateCmd, cmdCtx.Migration) + migrateCmd.Flags().BoolVar(&cmdCtx.Migration.UseSchemaDump, "use-schema-dump", false, + "Use the schema file to apply the migration (for fresh install without any migration)") registerBase(rollbackCmd, cmdCtx.Migration) rollbackCmd.Flags().IntVar(&cmdCtx.Migration.Steps, "steps", 1, "The number of steps to rollback") + } diff --git a/cmd/root.go b/cmd/root.go index be1e5ca..a587cd9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -27,7 +27,7 @@ var rootCmd = &cobra.Command{ Long: `Basic usage: First you need to create a main folder with amigo init: - will create a folder in migrations/db with a context file inside to not have to pass the dsn every time. + will create a folder in db with a context file inside to not have to pass the dsn every time. Postgres: $ amigo context --dsn "postgres://user:password@host:port/dbname?sslmode=disable" @@ -92,9 +92,15 @@ func init() { rootCmd.PersistentFlags().BoolVar(&cmdCtx.ShowSQLSyntaxHighlighting, "sql-syntax-highlighting", true, "Print SQL queries with syntax highlighting") - rootCmd.Flags().StringVar(&cmdCtx.PGDumpPath, "pg-dump-path", amigoctx.DefaultPGDumpPath, + rootCmd.PersistentFlags().StringVar(&cmdCtx.SchemaOutPath, "schema-out-path", amigoctx.DefaultSchemaOutPath, + "File path of the schema dump if any") + + rootCmd.PersistentFlags().StringVar(&cmdCtx.PGDumpPath, "pg-dump-path", amigoctx.DefaultPGDumpPath, "Path to the pg_dump command if --dump is set") + rootCmd.PersistentFlags().StringVar(&cmdCtx.SchemaDBDumpSchema, "schema-db-dump-schema", + amigoctx.DefaultDBDumpSchema, "Schema to use when dumping schema") + rootCmd.PersistentFlags().BoolVar(&cmdCtx.Debug, "debug", false, "Print debug information") initConfig() } @@ -122,7 +128,8 @@ func initConfig() { _ = viper.BindPFlag("pg-dump-path", createCmd.Flags().Lookup("pg-dump-path")) _ = viper.BindPFlag("sql", rootCmd.PersistentFlags().Lookup("sql")) _ = viper.BindPFlag("sql-syntax-highlighting", rootCmd.PersistentFlags().Lookup("sql-syntax-highlighting")) - + _ = viper.BindPFlag("schema-out-path", rootCmd.Flags().Lookup("schema-out-path")) + _ = viper.BindPFlag("schema-db-dump-schema", rootCmd.Flags().Lookup("schema-db-dump-schema")) _ = viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug")) viper.SetConfigFile(filepath.Join(cmdCtx.AmigoFolderPath, contextFileName)) @@ -171,12 +178,20 @@ func initConfig() { if viper.IsSet("debug") { cmdCtx.Debug = viper.GetBool("debug") } + + if viper.IsSet("schema-out-path") { + cmdCtx.SchemaOutPath = viper.GetString("schema-out-path") + } + + if viper.IsSet("schema-db-dump-schema") { + cmdCtx.SchemaDBDumpSchema = viper.GetString("schema-db-dump-schema") + } } func wrapCobraFunc(f func(cmd *cobra.Command, am amigo.Amigo, args []string) error) func(cmd *cobra.Command, args []string) { return func(cmd *cobra.Command, args []string) { am := amigo.NewAmigo(cmdCtx) - am.SetupSlog(os.Stdout) + am.SetupSlog(os.Stdout, nil) if err := f(cmd, am, args); err != nil { logger.Error(events.MessageEvent{Message: err.Error()}) diff --git a/cmd/schema.go b/cmd/schema.go new file mode 100644 index 0000000..4befaf2 --- /dev/null +++ b/cmd/schema.go @@ -0,0 +1,48 @@ +package cmd + +import ( + "fmt" + "path" + + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/utils" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" + "github.com/spf13/cobra" +) + +var schemaCmd = &cobra.Command{ + Use: "schema", + Short: "Dump the schema of the database using appropriate tool", + Long: `Dump the schema of the database using appropriate tool. +Supported databases: + - postgres with pg_dump`, + Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { + if err := cmdCtx.ValidateDSN(); err != nil { + return err + } + + return dumpSchema(am) + }), +} + +func dumpSchema(am amigo.Amigo) error { + file, err := utils.CreateOrOpenFile(cmdCtx.SchemaOutPath) + if err != nil { + return fmt.Errorf("unable to open/create file: %w", err) + } + + defer file.Close() + + err = am.DumpSchema(file, false) + if err != nil { + return fmt.Errorf("unable to dump schema: %w", err) + } + + logger.Info(events.FileModifiedEvent{FileName: path.Join(cmdCtx.SchemaOutPath)}) + return nil +} + +func init() { + rootCmd.AddCommand(schemaCmd) +} diff --git a/docs/docs/02-quick-start/02-initialize.md b/docs/docs/02-quick-start/02-initialize.md index af97bec..a7e83b5 100644 --- a/docs/docs/02-quick-start/02-initialize.md +++ b/docs/docs/02-quick-start/02-initialize.md @@ -2,7 +2,7 @@ To start using mig, you need to initialize it. This process creates few things: - A `migrations` folder where you will write your migrations. -- A `migrations/db` folder where mig stores its configuration and the main file to run migrations. +- A `db/migrations` folder where mig stores its configuration and the main file to run migrations. - A migration file to setup the table that will store the migration versions. To initialize mig, run the following command: diff --git a/pkg/amigo/amigo.go b/pkg/amigo/amigo.go index 3f88e7c..c1bc1c9 100644 --- a/pkg/amigo/amigo.go +++ b/pkg/amigo/amigo.go @@ -1,24 +1,8 @@ package amigo import ( - "database/sql" - "fmt" - "io" - "os" - "path" - "path/filepath" - "regexp" - "sort" - "strings" - "time" - "github.com/alexisvisco/amigo/pkg/amigoctx" - "github.com/alexisvisco/amigo/pkg/schema" - "github.com/alexisvisco/amigo/pkg/templates" "github.com/alexisvisco/amigo/pkg/types" - "github.com/alexisvisco/amigo/pkg/utils" - "github.com/alexisvisco/amigo/pkg/utils/cmdexec" - "github.com/gobuffalo/flect" ) type Amigo struct { @@ -33,265 +17,3 @@ func NewAmigo(ctx *amigoctx.Context) Amigo { Driver: types.GetDriver(ctx.DSN), } } - -// DumpSchema of the database and write it to the writer -func (a Amigo) DumpSchema() (string, error) { - db, err := schema.ExtractCredentials(a.ctx.GetRealDSN()) - if err != nil { - return "", err - } - - ignoreTableName := a.ctx.SchemaVersionTable - if strings.Contains(ignoreTableName, ".") { - ignoreTableName = strings.Split(ignoreTableName, ".")[1] - } - - args := []string{ - a.ctx.PGDumpPath, - "-d", db.DB, - "-h", db.Host, - "-U", db.User, - "-p", db.Port, - "-n", a.ctx.Create.DumpSchema, - "-s", - "-x", - "-O", - "-T", ignoreTableName, - "--no-comments", - "--no-owner", - "--no-privileges", - "--no-tablespaces", - "--no-security-labels", - } - - env := map[string]string{"PGPASSWORD": db.Pass} - - stdout, stderr, err := cmdexec.Exec(a.ctx.ShellPath, []string{"-c", strings.Join(args, " ")}, env) - if err != nil { - return "", fmt.Errorf("unable to dump database: %w\n%s", err, stderr) - } - - // replace all regexp listed below to empty string - regexpReplace := []string{ - `-- --- Name: .*; Type: SCHEMA; Schema: -; Owner: - --- - -CREATE SCHEMA .*; -`, - } - - for _, r := range regexpReplace { - stdout = regexp.MustCompile(r).ReplaceAllString(stdout, "") - } - - return stdout, nil -} - -// GenerateMainFile generate the main.go file in the amigo folder -func (a Amigo) GenerateMainFile(writer io.Writer) error { - name, err := utils.GetModuleName() - if err != nil { - return fmt.Errorf("unable to get module name: %w", err) - } - - packagePath := path.Join(name, a.ctx.MigrationFolder) - - template, err := templates.GetMainTemplate(templates.MainData{ - PackagePath: packagePath, - DriverPath: a.Driver.PackagePath(), - DriverName: a.Driver.String(), - }) - - if err != nil { - return fmt.Errorf("unable to get main template: %w", err) - } - - _, err = writer.Write([]byte(template)) - if err != nil { - return fmt.Errorf("unable to write main file: %w", err) - } - - return nil -} - -type GenerateMigrationFileParams struct { - Name string - Up string - Down string - Change string - Type types.MigrationFileType - Now time.Time - UseSchemaImport bool - UseFmtImport bool - Writer io.Writer -} - -// GenerateMigrationFile generate a migration file in the migrations folder -func (a Amigo) GenerateMigrationFile(params *GenerateMigrationFileParams) error { - - structName := utils.MigrationStructName(params.Now, params.Name) - - orDefault := func(s string) string { - if s == "" { - return "// TODO: implement the migration" - } - return s - } - - fileContent, err := templates.GetMigrationTemplate(templates.MigrationData{ - IsSQL: params.Type == types.MigrationFileTypeSQL, - Package: a.ctx.PackagePath, - StructName: structName, - Name: flect.Underscore(params.Name), - Type: params.Type, - InChange: orDefault(params.Change), - InUp: orDefault(params.Up), - InDown: orDefault(params.Down), - CreatedAt: params.Now.Format(time.RFC3339), - PackageDriverName: a.Driver.PackageName(), - PackageDriverPath: a.Driver.PackageSchemaPath(), - UseSchemaImport: params.UseSchemaImport, - UseFmtImport: params.UseFmtImport, - }) - - if err != nil { - return fmt.Errorf("unable to get migration template: %w", err) - } - - _, err = params.Writer.Write([]byte(fileContent)) - if err != nil { - return fmt.Errorf("unable to write migration file: %w", err) - } - - return nil -} - -// GenerateMigrationsFiles generate the migrations file in the migrations folder -// It's used to keep track of all migrations -func (a Amigo) GenerateMigrationsFiles(writer io.Writer) error { - migrationFiles, keys, err := a.GetMigrationFiles(true) - if err != nil { - return err - } - - var migrations []string - var mustImportSchemaPackage *string - for _, k := range keys { - if migrationFiles[k].IsSQL { - // schema.NewSQLMigration[*pg.Schema](sqlMigrationsFS, "20240602081806_drop_index.sql", "2024-06-02T10:18:06+02:00", "---- down:"), - line := fmt.Sprintf("schema.NewSQLMigration[%s](sqlMigrationsFS, \"%s\", \"%s\", \"%s\")", - a.Driver.StructName(), - migrationFiles[k].FulName, - k.Format(time.RFC3339), - a.ctx.Create.SQLSeparator, - ) - - migrations = append(migrations, line) - - if mustImportSchemaPackage == nil { - v := a.Driver.PackageSchemaPath() - mustImportSchemaPackage = &v - } - } else { - migrations = append(migrations, fmt.Sprintf("&%s{}", utils.MigrationStructName(k, migrationFiles[k].Name))) - - } - } - - content, err := templates.GetMigrationsTemplate(templates.MigrationsData{ - Package: a.ctx.PackagePath, - Migrations: migrations, - ImportSchemaPackage: mustImportSchemaPackage, - }) - - if err != nil { - return fmt.Errorf("unable to get migrations template: %w", err) - } - - _, err = writer.Write([]byte(content)) - if err != nil { - return fmt.Errorf("unable to write migrations file: %w", err) - } - - return nil -} - -// GetStatus return the state of the database -func (a Amigo) GetStatus(db *sql.DB) ([]string, error) { - rows, err := db.Query("SELECT version FROM " + a.ctx.SchemaVersionTable + " ORDER BY version desc") - if err != nil { - return nil, fmt.Errorf("unable to get state: %w", err) - } - - var state []string - for rows.Next() { - var id string - err := rows.Scan(&id) - if err != nil { - return nil, fmt.Errorf("unable to scan state: %w", err) - } - state = append(state, id) - } - - return state, nil -} - -type MigrationFile struct { - Name string - FulName string - IsSQL bool -} - -func (a Amigo) GetMigrationFiles(ascending bool) (map[time.Time]MigrationFile, []time.Time, error) { - migrationFiles := make(map[time.Time]MigrationFile) - - // get the list of structs by the file name - err := filepath.Walk(a.ctx.MigrationFolder, func(path string, info os.FileInfo, err error) error { - if err != nil { - return err - } - - if !info.IsDir() { - if utils.MigrationFileRegexp.MatchString(info.Name()) { - matches := utils.MigrationFileRegexp.FindStringSubmatch(info.Name()) - fileTime := matches[1] - migrationName := matches[2] - ext := matches[3] - - t, _ := time.Parse(utils.FormatTime, fileTime) - migrationFiles[t] = MigrationFile{Name: migrationName, IsSQL: ext == "sql", FulName: info.Name()} - } - } - - return nil - }) - if err != nil { - return nil, nil, fmt.Errorf("unable to walk through the migration folder: %w", err) - } - - // sort the files - var keys []time.Time - for k := range migrationFiles { - keys = append(keys, k) - } - - sort.Slice(keys, func(i, j int) bool { - if ascending { - return keys[i].Unix() < keys[j].Unix() - } else { - return keys[i].Unix() > keys[j].Unix() - } - }) - - return migrationFiles, keys, nil -} - -func (a Amigo) SkipMigrationFile(db *sql.DB) error { - _, err := db.Exec("INSERT INTO "+a.ctx.SchemaVersionTable+" (id) VALUES ($1)", a.ctx.Create.Version) - if err != nil { - return fmt.Errorf("unable to skip migration file: %w", err) - } - - return nil -} diff --git a/pkg/amigo/dump_schema.go b/pkg/amigo/dump_schema.go new file mode 100644 index 0000000..bca35e8 --- /dev/null +++ b/pkg/amigo/dump_schema.go @@ -0,0 +1,147 @@ +package amigo + +import ( + "fmt" + "io" + "regexp" + "sort" + "strings" + "time" + + "github.com/alexisvisco/amigo/pkg/schema" + "github.com/alexisvisco/amigo/pkg/utils/cmdexec" +) + +func (a Amigo) DumpSchema(writer io.Writer, ignoreSchemaVersionTable bool) error { + db, err := schema.ExtractCredentials(a.ctx.GetRealDSN()) + if err != nil { + return err + } + + ignoreTableName := a.ctx.SchemaVersionTable + if strings.Contains(ignoreTableName, ".") { + ignoreTableName = strings.Split(ignoreTableName, ".")[1] + } + + args := []string{ + a.ctx.PGDumpPath, + "-d", db.DB, + "-h", db.Host, + "-U", db.User, + "-p", db.Port, + "-n", a.ctx.SchemaDBDumpSchema, + "-s", + "-x", + "-O", + "--no-comments", + "--no-owner", + "--no-privileges", + "--no-tablespaces", + "--no-security-labels", + } + + if !ignoreSchemaVersionTable { + args = append(args, "-T="+ignoreTableName) + } + + env := map[string]string{"PGPASSWORD": db.Pass} + + stdout, stderr, err := cmdexec.Exec(a.ctx.ShellPath, []string{"-c", strings.Join(args, " ")}, env) + if err != nil { + return fmt.Errorf("unable to dump database: %w\n%s", err, stderr) + } + + // Generate extension statements + extensionsToAdd := autoDetectExtensions(stdout) + var extensionStatements strings.Builder + extensionStatements.WriteString("\n-- Create extensions if they don't exist\n") + for _, ext := range extensionsToAdd { + extensionStatements.WriteString(fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS \"%s\";\n", ext)) + } + extensionStatements.WriteString("\n") + + dateGenerated := fmt.Sprintf("-- Generated at: %s\n", time.Now().Format(time.RFC3339)) + + schemaPattern := fmt.Sprintf( + `(?s)(.*?)(?:--\s*\n--\s*Name:\s*%s;\s*Type:\s*SCHEMA.*?\n--\s*\n\s*CREATE\s+SCHEMA\s+%s;\s*\n)(.*)`, + regexp.QuoteMeta(a.ctx.SchemaDBDumpSchema), + regexp.QuoteMeta(a.ctx.SchemaDBDumpSchema), + ) + + dumpParts := regexp.MustCompile(schemaPattern).FindStringSubmatch(stdout) + if len(dumpParts) != 3 { + return fmt.Errorf("failed to parse schema dump: unexpected format") + } + + setSchemaPath := fmt.Sprintf("SET search_path TO %s;\n", a.ctx.SchemaDBDumpSchema) + + // Combine all parts with the proper ordering + result := dateGenerated + + dumpParts[1] + // Content before schema + setSchemaPath + + extensionStatements.String() + + dumpParts[2] // Content after schema + + writer.Write([]byte(result)) + + return nil +} + +func autoDetectExtensions(stdout string) []string { + extensions := make(map[string]bool) + + // Common patterns that indicate extension usage + patterns := map[string][]string{ + "uuid-ossp": { + `uuid_generate_v`, + `gen_random_uuid`, + }, + "hstore": { + `hstore_to_json`, + `hstore_to_array`, + }, + "postgis": { + `geometry_columns`, + `spatial_ref_sys`, + `ST_`, + }, + "pg_trgm": { + `similarity`, + `show_trgm`, + }, + "pgcrypto": { + `crypt(`, + `gen_salt`, + }, + "ltree": { + `ltree_gist`, + `lquery`, + }, + "citext": { + `citext_ops`, + }, + "tablefunc": { + `crosstab`, + `normal_rand`, + }, + } + + // Check each extension's patterns + for ext, searchPatterns := range patterns { + for _, pattern := range searchPatterns { + if strings.Contains(stdout, pattern) { + extensions[ext] = true + break + } + } + } + + // Convert map keys to sorted slice + result := make([]string, 0, len(extensions)) + for ext := range extensions { + result = append(result, ext) + } + sort.Strings(result) + + return result +} diff --git a/pkg/amigo/execute_main.go b/pkg/amigo/execute_main.go index f20b54c..3ab8294 100644 --- a/pkg/amigo/execute_main.go +++ b/pkg/amigo/execute_main.go @@ -1,14 +1,14 @@ package amigo import ( + "encoding/base64" + "encoding/json" "fmt" - "github.com/alexisvisco/amigo/pkg/utils" - "github.com/alexisvisco/amigo/pkg/utils/cmdexec" - "github.com/alexisvisco/amigo/pkg/utils/events" - "github.com/alexisvisco/amigo/pkg/utils/logger" "os" "path" "strings" + + "github.com/alexisvisco/amigo/pkg/utils/cmdexec" ) type MainArg string @@ -49,53 +49,15 @@ func (a Amigo) ExecuteMain(arg MainArg) error { return err } - args = []string{ - "./" + mainBinaryPath, - "-dsn", fmt.Sprintf(`"%s"`, a.ctx.DSN), - "-schema-version-table", a.ctx.SchemaVersionTable, - } - - if a.ctx.ShowSQL { - args = append(args, "-sql") - } - - if a.ctx.ShowSQLSyntaxHighlighting { - args = append(args, "-sql-syntax-highlighting") - } - - if a.ctx.Debug { - args = append(args, "-debug") - } - - if a.ctx.JSON { - args = append(args, "-json") - } - - if a.ctx.Debug { - logger.Debug(events.MessageEvent{Message: fmt.Sprintf("executing %s", args)}) + amigoCtxJson, err := json.Marshal(a.ctx) + if err != nil { + return fmt.Errorf("unable to marshal amigo context: %w", err) } - switch arg { - case MainArgMigrate, MainArgRollback: - if a.ctx.Migration.ContinueOnError { - args = append(args, "-continue-on-error") - } - - if a.ctx.Migration.DryRun { - args = append(args, "-dry-run") - } - - if a.ctx.Migration.Version != "" { - v, err := utils.ParseMigrationVersion(a.ctx.Migration.Version) - if err != nil { - return fmt.Errorf("unable to parse version: %w", err) - } - args = append(args, "-version", v) - } - - args = append(args, "-steps", fmt.Sprintf("%d", a.ctx.Migration.Steps)) - case MainArgSkipMigration: - args = append(args, "-version", a.ctx.Create.Version) + bas64Json := base64.StdEncoding.EncodeToString(amigoCtxJson) + args = []string{ + "./" + mainBinaryPath, + "-json", bas64Json, } args = append(args, string(arg)) diff --git a/pkg/amigo/gen_main_file.go b/pkg/amigo/gen_main_file.go new file mode 100644 index 0000000..88bdbfc --- /dev/null +++ b/pkg/amigo/gen_main_file.go @@ -0,0 +1,37 @@ +package amigo + +import ( + "fmt" + "io" + "path" + + "github.com/alexisvisco/amigo/pkg/templates" + "github.com/alexisvisco/amigo/pkg/utils" +) + +// GenerateMainFile generate the main.go file in the amigo folder +func (a Amigo) GenerateMainFile(writer io.Writer) error { + name, err := utils.GetModuleName() + if err != nil { + return fmt.Errorf("unable to get module name: %w", err) + } + + packagePath := path.Join(name, a.ctx.MigrationFolder) + + template, err := templates.GetMainTemplate(templates.MainData{ + PackagePath: packagePath, + DriverPath: a.Driver.PackagePath(), + DriverName: a.Driver.String(), + }) + + if err != nil { + return fmt.Errorf("unable to get main template: %w", err) + } + + _, err = writer.Write([]byte(template)) + if err != nil { + return fmt.Errorf("unable to write main file: %w", err) + } + + return nil +} diff --git a/pkg/amigo/gen_migration_file.go b/pkg/amigo/gen_migration_file.go new file mode 100644 index 0000000..02830e3 --- /dev/null +++ b/pkg/amigo/gen_migration_file.go @@ -0,0 +1,64 @@ +package amigo + +import ( + "fmt" + "io" + "time" + + "github.com/alexisvisco/amigo/pkg/templates" + "github.com/alexisvisco/amigo/pkg/types" + "github.com/alexisvisco/amigo/pkg/utils" + "github.com/gobuffalo/flect" +) + +type GenerateMigrationFileParams struct { + Name string + Up string + Down string + Change string + Type types.MigrationFileType + Now time.Time + UseSchemaImport bool + UseFmtImport bool + Writer io.Writer +} + +// GenerateMigrationFile generate a migration file in the migrations folder +func (a Amigo) GenerateMigrationFile(params *GenerateMigrationFileParams) error { + + structName := utils.MigrationStructName(params.Now, params.Name) + + orDefault := func(s string) string { + if s == "" { + return "// TODO: implement the migration" + } + return s + } + + fileContent, err := templates.GetMigrationTemplate(templates.MigrationData{ + IsSQL: params.Type == types.MigrationFileTypeSQL, + Package: a.ctx.PackagePath, + StructName: structName, + Name: flect.Underscore(params.Name), + Type: params.Type, + InChange: orDefault(params.Change), + InUp: orDefault(params.Up), + InDown: orDefault(params.Down), + CreatedAt: params.Now.Format(time.RFC3339), + PackageDriverName: a.Driver.PackageName(), + PackageDriverPath: a.Driver.PackageSchemaPath(), + UseSchemaImport: params.UseSchemaImport, + UseFmtImport: params.UseFmtImport, + }) + + if err != nil { + return fmt.Errorf("unable to get migration template: %w", err) + } + + _, err = params.Writer.Write([]byte(fileContent)) + if err != nil { + return fmt.Errorf("unable to write migration file: %w", err) + } + + return nil +} diff --git a/pkg/amigo/gen_migrations_file.go b/pkg/amigo/gen_migrations_file.go new file mode 100644 index 0000000..664839a --- /dev/null +++ b/pkg/amigo/gen_migrations_file.go @@ -0,0 +1,113 @@ +package amigo + +import ( + "fmt" + "io" + "os" + "path/filepath" + "sort" + "time" + + "github.com/alexisvisco/amigo/pkg/templates" + "github.com/alexisvisco/amigo/pkg/utils" +) + +// GenerateMigrationsFiles generate the migrations file in the migrations folder +// It's used to keep track of all migrations +func (a Amigo) GenerateMigrationsFiles(writer io.Writer) error { + migrationFiles, keys, err := a.getMigrationFiles(true) + if err != nil { + return err + } + + var migrations []string + var mustImportSchemaPackage *string + for _, k := range keys { + if migrationFiles[k].isSQL { + // schema.NewSQLMigration[*pg.Schema](sqlMigrationsFS, "20240602081806_drop_index.sql", "2024-06-02T10:18:06+02:00", "---- down:"), + line := fmt.Sprintf("schema.NewSQLMigration[%s](sqlMigrationsFS, \"%s\", \"%s\", \"%s\")", + a.Driver.StructName(), + migrationFiles[k].fulName, + k.Format(time.RFC3339), + a.ctx.Create.SQLSeparator, + ) + + migrations = append(migrations, line) + + if mustImportSchemaPackage == nil { + v := a.Driver.PackageSchemaPath() + mustImportSchemaPackage = &v + } + } else { + migrations = append(migrations, fmt.Sprintf("&%s{}", utils.MigrationStructName(k, migrationFiles[k].Name))) + + } + } + + content, err := templates.GetMigrationsTemplate(templates.MigrationsData{ + Package: a.ctx.PackagePath, + Migrations: migrations, + ImportSchemaPackage: mustImportSchemaPackage, + }) + + if err != nil { + return fmt.Errorf("unable to get migrations template: %w", err) + } + + _, err = writer.Write([]byte(content)) + if err != nil { + return fmt.Errorf("unable to write migrations file: %w", err) + } + + return nil +} + +type migrationFile struct { + Name string + fulName string + isSQL bool +} + +func (a Amigo) getMigrationFiles(ascending bool) (map[time.Time]migrationFile, []time.Time, error) { + migrationFiles := make(map[time.Time]migrationFile) + + // get the list of structs by the file name + err := filepath.Walk(a.ctx.MigrationFolder, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if !info.IsDir() { + if utils.MigrationFileRegexp.MatchString(info.Name()) { + matches := utils.MigrationFileRegexp.FindStringSubmatch(info.Name()) + fileTime := matches[1] + migrationName := matches[2] + ext := matches[3] + + t, _ := time.Parse(utils.FormatTime, fileTime) + migrationFiles[t] = migrationFile{Name: migrationName, isSQL: ext == "sql", fulName: info.Name()} + } + } + + return nil + }) + if err != nil { + return nil, nil, fmt.Errorf("unable to walk through the migration folder: %w", err) + } + + // sort the files + var keys []time.Time + for k := range migrationFiles { + keys = append(keys, k) + } + + sort.Slice(keys, func(i, j int) bool { + if ascending { + return keys[i].Unix() < keys[j].Unix() + } else { + return keys[i].Unix() > keys[j].Unix() + } + }) + + return migrationFiles, keys, nil +} diff --git a/pkg/amigo/get_status.go b/pkg/amigo/get_status.go new file mode 100644 index 0000000..00f586b --- /dev/null +++ b/pkg/amigo/get_status.go @@ -0,0 +1,26 @@ +package amigo + +import ( + "database/sql" + "fmt" +) + +// GetStatus return the state of the database +func (a Amigo) GetStatus(db *sql.DB) ([]string, error) { + rows, err := db.Query("SELECT version FROM " + a.ctx.SchemaVersionTable + " ORDER BY version desc") + if err != nil { + return nil, fmt.Errorf("unable to get state: %w", err) + } + + var state []string + for rows.Next() { + var id string + err := rows.Scan(&id) + if err != nil { + return nil, fmt.Errorf("unable to scan state: %w", err) + } + state = append(state, id) + } + + return state, nil +} diff --git a/pkg/amigo/run_migration.go b/pkg/amigo/run_migration.go index 13bee53..cff6fdb 100644 --- a/pkg/amigo/run_migration.go +++ b/pkg/amigo/run_migration.go @@ -4,8 +4,10 @@ import ( "context" "database/sql" "errors" + "fmt" "io" "log/slog" + "path" "time" "github.com/alexisvisco/amigo/pkg/amigoctx" @@ -16,6 +18,8 @@ import ( "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils" "github.com/alexisvisco/amigo/pkg/utils/dblog" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" sqldblogger "github.com/simukti/sqldb-logger" ) @@ -34,6 +38,7 @@ type RunMigrationParams struct { Migrations []schema.Migration LogOutput io.Writer Context context.Context + Logger *slog.Logger } // RunMigrations migrates the database, it is launched via the generated main file or manually in a codebase. @@ -51,12 +56,7 @@ func (a Amigo) RunMigrations(params RunMigrationParams) error { ctx, cancel := context.WithDeadline(originCtx, time.Now().Add(a.ctx.Migration.Timeout)) defer cancel() - oldLogger := slog.Default() - defer func() { - slog.SetDefault(oldLogger) - }() - - a.SetupSlog(params.LogOutput) + a.SetupSlog(params.LogOutput, params.Logger) migrator, err := a.getMigrationApplier(ctx, params.DB) if err != nil { @@ -74,6 +74,22 @@ func (a Amigo) RunMigrations(params RunMigrationParams) error { return ErrMigrationFailed } + if a.ctx.Migration.DumpSchemaAfter { + file, err := utils.CreateOrOpenFile(a.ctx.SchemaOutPath) + if err != nil { + return fmt.Errorf("unable to open/create file: %w", err) + } + + defer file.Close() + + err = a.DumpSchema(file, false) + if err != nil { + return fmt.Errorf("unable to dump schema after migrating: %w", err) + } + + logger.Info(events.FileModifiedEvent{FileName: path.Join(a.ctx.SchemaOutPath)}) + } + return nil } @@ -113,6 +129,8 @@ func (a Amigo) getMigrationApplier( ContinueOnError: a.ctx.Migration.ContinueOnError, SchemaVersionTable: schema.TableName(a.ctx.SchemaVersionTable), DBLogger: recorder, + DumpSchemaFilePath: utils.NilOrValue(a.ctx.SchemaOutPath), + UseSchemaDump: a.ctx.Migration.UseSchemaDump, } switch a.Driver { diff --git a/pkg/amigo/setup_slog.go b/pkg/amigo/setup_slog.go index 2484fad..189aaaf 100644 --- a/pkg/amigo/setup_slog.go +++ b/pkg/amigo/setup_slog.go @@ -7,16 +7,16 @@ import ( "github.com/alexisvisco/amigo/pkg/utils/logger" ) -func (a Amigo) SetupSlog(writer io.Writer) { +func (a Amigo) SetupSlog(writer io.Writer, mayLogger *slog.Logger) { logger.ShowSQLEvents = a.ctx.ShowSQL - if writer == nil { + if writer == nil && mayLogger == nil { + logger.Logger = slog.New(slog.NewJSONHandler(writer, &slog.HandlerOptions{Level: slog.LevelError})) return } - if a.ctx.JSON { - slog.SetDefault(slog.New(slog.NewJSONHandler(writer, nil))) - } else { - slog.SetDefault(slog.New(logger.NewHandler(writer, nil))) + if mayLogger != nil { + logger.Logger = mayLogger + return } level := slog.LevelInfo @@ -24,5 +24,9 @@ func (a Amigo) SetupSlog(writer io.Writer) { level = slog.LevelDebug } - slog.SetLogLoggerLevel(level) + if a.ctx.JSON { + logger.Logger = slog.New(slog.NewJSONHandler(writer, &slog.HandlerOptions{Level: level})) + } else { + logger.Logger = slog.New(logger.NewHandler(writer, &logger.Options{Level: level})) + } } diff --git a/pkg/amigo/skip_migration_file.go b/pkg/amigo/skip_migration_file.go new file mode 100644 index 0000000..f9837c1 --- /dev/null +++ b/pkg/amigo/skip_migration_file.go @@ -0,0 +1,15 @@ +package amigo + +import ( + "database/sql" + "fmt" +) + +func (a Amigo) SkipMigrationFile(db *sql.DB) error { + _, err := db.Exec("INSERT INTO "+a.ctx.SchemaVersionTable+" (id) VALUES ($1)", a.ctx.Create.Version) + if err != nil { + return fmt.Errorf("unable to skip migration file: %w", err) + } + + return nil +} diff --git a/pkg/amigoctx/ctx.go b/pkg/amigoctx/ctx.go index fc94c93..bf8086e 100644 --- a/pkg/amigoctx/ctx.go +++ b/pkg/amigoctx/ctx.go @@ -16,12 +16,14 @@ var ( var ( DefaultSchemaVersionTable = "public.mig_schema_versions" - DefaultAmigoFolder = "migrations/db" - DefaultMigrationFolder = "migrations" + DefaultAmigoFolder = "db" + DefaultMigrationFolder = "db/migrations" DefaultPackagePath = "migrations" DefaultShellPath = "/bin/bash" DefaultPGDumpPath = "pg_dump" + DefaultSchemaOutPath = "db/schema.sql" DefaultTimeout = 2 * time.Minute + DefaultDBDumpSchema = "public" ) type Context struct { @@ -40,6 +42,8 @@ func NewContext() *Context { PackagePath: DefaultPackagePath, ShellPath: DefaultShellPath, PGDumpPath: DefaultPGDumpPath, + SchemaOutPath: DefaultSchemaOutPath, + SchemaDBDumpSchema: DefaultDBDumpSchema, }, Migration: &Migration{ Timeout: DefaultTimeout, @@ -60,6 +64,8 @@ type Root struct { SchemaVersionTable string ShellPath string PGDumpPath string + SchemaOutPath string + SchemaDBDumpSchema string Debug bool } @@ -102,6 +108,15 @@ func (a *Context) WithVersion(version string) *Context { return a } +func (a *Context) WithDumpSchemaAfterMigrating(dumpSchema bool) *Context { + if a.Migration == nil { + a.Migration = &Migration{} + } + a.Migration.DumpSchemaAfter = dumpSchema + + return a +} + func (a *Context) WithSteps(steps int) *Context { a.Migration.Steps = steps return a @@ -126,6 +141,8 @@ type Migration struct { DryRun bool ContinueOnError bool Timeout time.Duration + UseSchemaDump bool + DumpSchemaAfter bool } func (m *Migration) ValidateVersion() error { @@ -142,9 +159,9 @@ func (m *Migration) ValidateVersion() error { } type Create struct { - Type string - Dump bool - DumpSchema string + Type string + Dump bool + SQLSeparator string Skip bool @@ -207,6 +224,14 @@ func MergeContext(toMerge Context) *Context { if toMerge.Root.Debug { defaultCtx.Root.Debug = toMerge.Root.Debug } + + if toMerge.Root.SchemaDBDumpSchema != "" { + defaultCtx.Root.SchemaDBDumpSchema = toMerge.Root.SchemaDBDumpSchema + } + + if toMerge.Root.SchemaOutPath != "" { + defaultCtx.Root.SchemaOutPath = toMerge.Root.SchemaOutPath + } } if toMerge.Migration != nil { @@ -229,6 +254,10 @@ func MergeContext(toMerge Context) *Context { if toMerge.Migration.Timeout != 0 { defaultCtx.Migration.Timeout = toMerge.Migration.Timeout } + + if toMerge.Migration.UseSchemaDump { + defaultCtx.Migration.UseSchemaDump = toMerge.Migration.UseSchemaDump + } } if toMerge.Create != nil { @@ -240,10 +269,6 @@ func MergeContext(toMerge Context) *Context { defaultCtx.Create.Dump = toMerge.Create.Dump } - if toMerge.Create.DumpSchema != "" { - defaultCtx.Create.DumpSchema = toMerge.Create.DumpSchema - } - if toMerge.Create.Skip { defaultCtx.Create.Skip = toMerge.Create.Skip } diff --git a/pkg/entrypoint/main.go b/pkg/entrypoint/main.go index 83b31b4..88911f6 100644 --- a/pkg/entrypoint/main.go +++ b/pkg/entrypoint/main.go @@ -2,8 +2,14 @@ package entrypoint import ( "database/sql" + "encoding/base64" + "encoding/json" "flag" "fmt" + "os" + "strings" + "text/tabwriter" + "github.com/alexisvisco/amigo/pkg/amigo" "github.com/alexisvisco/amigo/pkg/amigoctx" "github.com/alexisvisco/amigo/pkg/schema" @@ -12,15 +18,11 @@ import ( "github.com/alexisvisco/amigo/pkg/utils/colors" "github.com/alexisvisco/amigo/pkg/utils/events" "github.com/alexisvisco/amigo/pkg/utils/logger" - "os" - "strings" - "text/tabwriter" - "time" ) func Main(db *sql.DB, arg amigo.MainArg, migrations []schema.Migration, ctx *amigoctx.Context) { am := amigo.NewAmigo(ctx) - am.SetupSlog(os.Stdout) + am.SetupSlog(os.Stdout, nil) switch arg { case amigo.MainArgMigrate, amigo.MainArgRollback: @@ -95,24 +97,8 @@ func sliceArrayOrDefault[T any](array []T, x int) []T { } func AmigoContextFromFlags() (*amigoctx.Context, amigo.MainArg) { - dsnFlag := flag.String("dsn", "", "URL connection to the database") - jsonFlag := flag.Bool("json", false, "Print the output in JSON") - showSQLFlag := flag.Bool("sql", false, "Print SQL statements") - schemaVersionTableFlag := flag.String("schema-version-table", "mig_schema_versions", - "Table name for the schema version") - debugFlag := flag.Bool("debug", false, "Print debug information") - - versionFlag := flag.String("version", "", "Apply or rollback a specific version") - timeoutFlag := flag.Duration("timeout", time.Minute*2, - "Timeout for the migration is the time for the whole migrations to be applied") // not working - dryRunFlag := flag.Bool("dry-run", false, "Dry run the migration will not apply the migration to the database") - continueOnErrorFlag := flag.Bool("continue-on-error", false, - "Continue on error will not rollback the migration if an error occurs") - stepsFlag := flag.Int("steps", 1, "Number of steps to rollback") - showSQLSyntaxHighlightingFlag := flag.Bool("sql-syntax-highlighting", false, - "Print SQL statements with syntax highlighting") - - // Parse flags + jsonFlag := flag.String("json", "", "all amigo context in json | bas64") + flag.Parse() if flag.NArg() == 0 { @@ -126,41 +112,19 @@ func AmigoContextFromFlags() (*amigoctx.Context, amigo.MainArg) { os.Exit(1) } - a := &amigoctx.Context{ - Root: &amigoctx.Root{ - AmigoFolderPath: "", - DSN: *dsnFlag, - JSON: *jsonFlag, - ShowSQL: *showSQLFlag, - MigrationFolder: "", - PackagePath: "", - SchemaVersionTable: *schemaVersionTableFlag, - ShellPath: "", - PGDumpPath: "", - Debug: *debugFlag, - ShowSQLSyntaxHighlighting: *showSQLSyntaxHighlightingFlag, - }, - } - - switch arg { - case amigo.MainArgMigrate: - a.Migration = &amigoctx.Migration{ - Version: *versionFlag, - DryRun: *dryRunFlag, - ContinueOnError: *continueOnErrorFlag, - Timeout: *timeoutFlag, - } - case amigo.MainArgRollback: - a.Migration = &amigoctx.Migration{ - Version: *versionFlag, - ContinueOnError: *continueOnErrorFlag, - Timeout: *timeoutFlag, - Steps: *stepsFlag, - DryRun: *dryRunFlag, + a := amigoctx.NewContext() + if *jsonFlag != "" { + b64decoded, err := base64.StdEncoding.DecodeString(*jsonFlag) + if err != nil { + logger.Error(events.MessageEvent{Message: fmt.Sprintf("unable to unmarshal amigo context b64: %s", + err.Error())}) + os.Exit(1) } - case amigo.MainArgSkipMigration: - a.Create = &amigoctx.Create{ - Version: *versionFlag, + err = json.Unmarshal(b64decoded, a) + if err != nil { + logger.Error(events.MessageEvent{Message: fmt.Sprintf("unable to unmarshal amigo context json: %s", + err.Error())}) + os.Exit(1) } } diff --git a/pkg/schema/base/base.go b/pkg/schema/base/base.go index 1da1724..7c31a0e 100644 --- a/pkg/schema/base/base.go +++ b/pkg/schema/base/base.go @@ -65,6 +65,22 @@ func (p *Schema) AddVersion(version string) { p.Context.AddVersionCreated(version) } +func (p Schema) AddVersions(versions []string) { + sql := `INSERT INTO {version_table} (version) VALUES {versions}` + replacer := utils.Replacer{ + "version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()), + "versions": utils.StrFunc(fmt.Sprintf("('%s')", strings.Join(versions, "'), ('"))), + } + + _, err := p.TX.ExecContext(p.Context.Context, replacer.Replace(sql)) + if err != nil { + p.Context.RaiseError(fmt.Errorf("error while adding versions: %w", err)) + return + } + + p.Context.AddVersionsCreated(versions) +} + // RemoveVersion removes a version from the schema_migrations table. // This function is not reversible. func (p *Schema) RemoveVersion(version string) { diff --git a/pkg/schema/detect_migrations.go b/pkg/schema/detect_migrations.go new file mode 100644 index 0000000..dc9ddd7 --- /dev/null +++ b/pkg/schema/detect_migrations.go @@ -0,0 +1,87 @@ +package schema + +import ( + "fmt" + "slices" + + "github.com/alexisvisco/amigo/pkg/types" + "github.com/alexisvisco/amigo/pkg/utils" +) + +func (m *Migrator[T]) detectMigrationsToExec( + s Schema, + migrationDirection types.MigrationDirection, + allMigrations []Migration, + version *string, + steps *int, // only used for rollback +) (migrationsToApply []Migration, firstRun bool) { + s.FindAppliedVersions() + appliedVersions, err := utils.PanicToError1(s.FindAppliedVersions) + if isTableDoesNotExists(err) { + firstRun = true + appliedVersions = []string{} + } else if err != nil { + m.ctx.RaiseError(err) + } + + var versionsToApply []Migration + var migrationsTimeFormat []string + var versionToMigration = make(map[string]Migration) + + for _, migration := range allMigrations { + migrationsTimeFormat = append(migrationsTimeFormat, migration.Date().UTC().Format(utils.FormatTime)) + versionToMigration[migrationsTimeFormat[len(migrationsTimeFormat)-1]] = migration + } + + switch migrationDirection { + case types.MigrationDirectionUp: + if version != nil && *version != "" { + if _, ok := versionToMigration[*version]; !ok { + m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) + } + + if slices.Contains(appliedVersions, *version) { + m.ctx.RaiseError(fmt.Errorf("version %s already applied", *version)) + } + + versionsToApply = append(versionsToApply, versionToMigration[*version]) + break + } + + for _, currentMigrationVersion := range migrationsTimeFormat { + if !slices.Contains(appliedVersions, currentMigrationVersion) { + versionsToApply = append(versionsToApply, versionToMigration[currentMigrationVersion]) + } + } + case types.MigrationDirectionDown: + if version != nil && *version != "" { + if _, ok := versionToMigration[*version]; !ok { + m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) + } + + if !slices.Contains(appliedVersions, *version) { + m.ctx.RaiseError(fmt.Errorf("version %s not applied", *version)) + } + + versionsToApply = append(versionsToApply, versionToMigration[*version]) + break + } + + step := 1 + if steps != nil && *steps > 0 { + step = *steps + } + + for i := len(allMigrations) - 1; i >= 0; i-- { + if slices.Contains(appliedVersions, migrationsTimeFormat[i]) { + versionsToApply = append(versionsToApply, versionToMigration[migrationsTimeFormat[i]]) + } + + if len(versionsToApply) == step { + break + } + } + } + + return versionsToApply, firstRun +} diff --git a/pkg/schema/helpers.go b/pkg/schema/helpers.go new file mode 100644 index 0000000..8d1aa72 --- /dev/null +++ b/pkg/schema/helpers.go @@ -0,0 +1,24 @@ +package schema + +import "regexp" + +func isTableDoesNotExists(err error) bool { + if err == nil { + return false + } + + re := []*regexp.Regexp{ + regexp.MustCompile(`Error 1146 \(42S02\): Table '.*' doesn't exist`), + regexp.MustCompile(`ERROR: relation ".*" does not exist \(SQLSTATE 42P01\)`), + regexp.MustCompile(`no such table: .*`), + regexp.MustCompile(`.*does not exist \(SQLSTATE=42P01\).*`), + } + + for _, r := range re { + if r.MatchString(err.Error()) { + return true + } + } + + return false +} diff --git a/pkg/schema/migrator_test.go b/pkg/schema/helpers_test.go similarity index 100% rename from pkg/schema/migrator_test.go rename to pkg/schema/helpers_test.go diff --git a/pkg/schema/migrator.go b/pkg/schema/migrator.go index 82821e7..6ffff2a 100644 --- a/pkg/schema/migrator.go +++ b/pkg/schema/migrator.go @@ -4,8 +4,6 @@ import ( "context" "fmt" "reflect" - "regexp" - "slices" "time" "github.com/alexisvisco/amigo/pkg/types" @@ -28,6 +26,12 @@ type MigratorOption struct { SchemaVersionTable TableName DBLogger dblog.DatabaseLogger + + // DumpSchemaFilePath is the path to the schema dump file. + DumpSchemaFilePath *string + + // UseSchemaDump specifies if the migrator should use the schema. (if possible -> for fresh installation) + UseSchemaDump bool } // Migration is the interface that describes a migration at is simplest form. @@ -83,7 +87,7 @@ func NewMigrator[T Schema]( func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, steps *int, migrations []Migration) bool { db := m.schemaFactory(m.ctx, m.db, m.db) - migrationsToExecute := m.findMigrationsToExecute( + migrationsToExecute, firstRun := m.detectMigrationsToExec( db, direction, migrations, @@ -96,6 +100,18 @@ func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, return true } + if firstRun && m.ctx.MigratorOptions.UseSchemaDump { + logger.Info(events.MessageEvent{Message: "We detect a fresh installation and applied the schema dump"}) + err := m.tryMigrateWithSchemaDump(migrationsToExecute) + if err != nil { + logger.Error(events.MessageEvent{Message: fmt.Sprintf("unable to apply schema dump: %v", err)}) + return false + } + + logger.Info(events.MessageEvent{Message: "Schema dump applied successfully"}) + return true + } + m.ToggleDBLog(true) defer m.ToggleDBLog(false) @@ -143,141 +159,6 @@ func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, return true } -func (m *Migrator[T]) findMigrationsToExecute( - s Schema, - migrationDirection types.MigrationDirection, - allMigrations []Migration, - version *string, - steps *int, // only used for rollback -) []Migration { - appliedVersions, err := utils.PanicToError1(s.FindAppliedVersions) - if isTableDoesNotExists(err) { - appliedVersions = []string{} - } else if err != nil { - m.ctx.RaiseError(err) - } - - var versionsToApply []Migration - var migrationsTimeFormat []string - var versionToMigration = make(map[string]Migration) - - for _, migration := range allMigrations { - migrationsTimeFormat = append(migrationsTimeFormat, migration.Date().UTC().Format(utils.FormatTime)) - versionToMigration[migrationsTimeFormat[len(migrationsTimeFormat)-1]] = migration - } - - switch migrationDirection { - case types.MigrationDirectionUp: - if version != nil && *version != "" { - if _, ok := versionToMigration[*version]; !ok { - m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) - } - - if slices.Contains(appliedVersions, *version) { - m.ctx.RaiseError(fmt.Errorf("version %s already applied", *version)) - } - - versionsToApply = append(versionsToApply, versionToMigration[*version]) - break - } - - for _, currentMigrationVersion := range migrationsTimeFormat { - if !slices.Contains(appliedVersions, currentMigrationVersion) { - versionsToApply = append(versionsToApply, versionToMigration[currentMigrationVersion]) - } - } - case types.MigrationDirectionDown: - if version != nil && *version != "" { - if _, ok := versionToMigration[*version]; !ok { - m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) - } - - if !slices.Contains(appliedVersions, *version) { - m.ctx.RaiseError(fmt.Errorf("version %s not applied", *version)) - } - - versionsToApply = append(versionsToApply, versionToMigration[*version]) - break - } - - step := 1 - if steps != nil && *steps > 0 { - step = *steps - } - - for i := len(allMigrations) - 1; i >= 0; i-- { - if slices.Contains(appliedVersions, migrationsTimeFormat[i]) { - versionsToApply = append(versionsToApply, versionToMigration[migrationsTimeFormat[i]]) - } - - if len(versionsToApply) == step { - break - } - } - } - - return versionsToApply -} - -// run runs the migration. -func (m *Migrator[T]) run(migrationType types.MigrationDirection, version string, f func(T)) (ok bool) { - currentContext := m.ctx - currentContext.MigrationDirection = migrationType - - tx, err := m.db.BeginTx(currentContext.Context, nil) - if err != nil { - logger.Error(events.MessageEvent{Message: "unable to start transaction"}) - return false - } - - schema := m.schemaFactory(currentContext, tx, m.db) - - handleError := func(err any) { - if err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("migration failed, rollback due to: %v", err)}) - - err := tx.Rollback() - if err != nil { - logger.Error(events.MessageEvent{Message: "unable to rollback transaction"}) - } - - ok = false - } - } - - defer func() { - if r := recover(); r != nil { - handleError(r) - } - }() - - f(schema) - - switch migrationType { - case types.MigrationDirectionUp: - schema.AddVersion(version) - case types.MigrationDirectionDown, types.MigrationDirectionNotReversible: - schema.RemoveVersion(version) - } - - if m.ctx.MigratorOptions.DryRun { - logger.Info(events.MessageEvent{Message: "migration in dry run mode, rollback transaction..."}) - err := tx.Rollback() - if err != nil { - logger.Error(events.MessageEvent{Message: "unable to rollback transaction"}) - } - return true - } else { - err := tx.Commit() - if err != nil { - logger.Error(events.MessageEvent{Message: "unable to commit transaction"}) - return false - } - } - - return true -} - func (m *Migrator[T]) NewSchema() T { return m.schemaFactory(m.ctx, m.db, m.db) } @@ -292,24 +173,3 @@ func (m *Migrator[T]) ToggleDBLog(b bool) { m.Options().DBLogger.ToggleLogger(b) } } - -func isTableDoesNotExists(err error) bool { - if err == nil { - return false - } - - re := []*regexp.Regexp{ - regexp.MustCompile(`Error 1146 \(42S02\): Table '.*' doesn't exist`), - regexp.MustCompile(`ERROR: relation ".*" does not exist \(SQLSTATE 42P01\)`), - regexp.MustCompile(`no such table: .*`), - regexp.MustCompile(`.*does not exist \(SQLSTATE=42P01\).*`), - } - - for _, r := range re { - if r.MatchString(err.Error()) { - return true - } - } - - return false -} diff --git a/pkg/schema/migrator_context.go b/pkg/schema/migrator_context.go index 0426c82..dbcbbd4 100644 --- a/pkg/schema/migrator_context.go +++ b/pkg/schema/migrator_context.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils/events" "github.com/alexisvisco/amigo/pkg/utils/logger" @@ -176,6 +177,14 @@ func (m *MigratorContext) AddVersionCreated(version string) { logger.Info(events.VersionAddedEvent{Version: version}) } +func (m *MigratorContext) AddVersionsCreated(versions []string) { + for _, version := range versions { + m.MigrationEvents.versionCreated = append(m.MigrationEvents.versionCreated, version) + } + + // no logging here, special case +} + func (m *MigratorContext) AddVersionDeleted(version string) { m.MigrationEvents.versionDeleted = append(m.MigrationEvents.versionDeleted, version) logger.Info(events.VersionDeletedEvent{Version: version}) diff --git a/pkg/schema/run_migration.go b/pkg/schema/run_migration.go new file mode 100644 index 0000000..b5c2db1 --- /dev/null +++ b/pkg/schema/run_migration.go @@ -0,0 +1,68 @@ +package schema + +import ( + "fmt" + + "github.com/alexisvisco/amigo/pkg/types" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" +) + +// run runs the migration. +func (m *Migrator[T]) run(migrationType types.MigrationDirection, version string, f func(T)) (ok bool) { + currentContext := m.ctx + currentContext.MigrationDirection = migrationType + + tx, err := m.db.BeginTx(currentContext.Context, nil) + if err != nil { + logger.Error(events.MessageEvent{Message: "unable to start transaction"}) + return false + } + + schema := m.schemaFactory(currentContext, tx, m.db) + + handleError := func(err any) { + if err != nil { + logger.Error(events.MessageEvent{Message: fmt.Sprintf("migration failed, rollback due to: %v", err)}) + + err := tx.Rollback() + if err != nil { + logger.Error(events.MessageEvent{Message: "unable to rollback transaction"}) + } + + ok = false + } + } + + defer func() { + if r := recover(); r != nil { + handleError(r) + } + }() + + f(schema) + + switch migrationType { + case types.MigrationDirectionUp: + schema.AddVersion(version) + case types.MigrationDirectionDown, types.MigrationDirectionNotReversible: + schema.RemoveVersion(version) + } + + if m.ctx.MigratorOptions.DryRun { + logger.Info(events.MessageEvent{Message: "migration in dry run mode, rollback transaction..."}) + err := tx.Rollback() + if err != nil { + logger.Error(events.MessageEvent{Message: "unable to rollback transaction"}) + } + return true + } else { + err := tx.Commit() + if err != nil { + logger.Error(events.MessageEvent{Message: "unable to commit transaction"}) + return false + } + } + + return true +} diff --git a/pkg/schema/run_migration_schema_dump.go b/pkg/schema/run_migration_schema_dump.go new file mode 100644 index 0000000..3e17b62 --- /dev/null +++ b/pkg/schema/run_migration_schema_dump.go @@ -0,0 +1,54 @@ +package schema + +import ( + "errors" + "fmt" + "os" + + "github.com/alexisvisco/amigo/pkg/utils" + "github.com/alexisvisco/amigo/pkg/utils/logger" +) + +// tryMigrateWithSchemaDump tries to migrate with schema dump. +// this might be executed when the user arrives on a repo with a schema.sql, instead of running +// all the migrations we will try to dump the schema and apply it. Then tell we applied all versions. +func (m *Migrator[T]) tryMigrateWithSchemaDump(migrations []Migration) error { + if m.ctx.MigratorOptions.DumpSchemaFilePath == nil { + return errors.New("no schema dump file path provided") + } + + file, err := os.ReadFile(*m.ctx.MigratorOptions.DumpSchemaFilePath) + if err != nil { + return fmt.Errorf("unable to read schema dump file: %w", err) + } + + logger.ShowSQLEvents = false + + tx, err := m.db.BeginTx(m.ctx.Context, nil) + if err != nil { + return fmt.Errorf("unable to start transaction: %w", err) + } + + defer tx.Rollback() + + tx.ExecContext(m.ctx.Context, "SET search_path TO public") + _, err = tx.ExecContext(m.ctx.Context, string(file)) + if err != nil { + return fmt.Errorf("unable to apply schema dump: %w", err) + } + + tx.Commit() + + schema := m.NewSchema() + + versions := make([]string, 0, len(migrations)) + for _, migration := range migrations { + versions = append(versions, fmt.Sprint(migration.Date().UTC().Format(utils.FormatTime))) + } + + logger.ShowSQLEvents = false + + schema.AddVersions(versions) + + return nil +} diff --git a/pkg/schema/schema.go b/pkg/schema/schema.go index 1eda97b..58ef638 100644 --- a/pkg/schema/schema.go +++ b/pkg/schema/schema.go @@ -5,6 +5,7 @@ import "database/sql" // Schema is the interface that need to be implemented to support migrations. type Schema interface { AddVersion(version string) + AddVersions(versions []string) RemoveVersion(version string) FindAppliedVersions() []string diff --git a/pkg/schema/options.go b/pkg/schema/schema_options.go similarity index 100% rename from pkg/schema/options.go rename to pkg/schema/schema_options.go diff --git a/pkg/schema/sqlite/column.go b/pkg/schema/sqlite/column.go index fa1bd6f..7f3ada0 100644 --- a/pkg/schema/sqlite/column.go +++ b/pkg/schema/sqlite/column.go @@ -94,7 +94,6 @@ func (p *Schema) column(options schema.ColumnOptions) string { s += options.ColumnType } s = "PRIMARY KEY" - fmt.Println(options.ColumnType) if options.ColumnType == schema.ColumnTypeSerial || options.ColumnType == schema.ColumnTypePrimaryKey { s += " AUTOINCREMENT" } diff --git a/pkg/utils/files.go b/pkg/utils/files.go index fe185a3..d9845ce 100644 --- a/pkg/utils/files.go +++ b/pkg/utils/files.go @@ -16,3 +16,13 @@ func CreateOrOpenFile(path string) (*os.File, error) { // create or open file return os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) } + +func EnsurePrentDirExists(path string) error { + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + if !os.IsExist(err) { + return fmt.Errorf("unable to create parent directory: %w", err) + } + } + + return nil +} diff --git a/pkg/utils/logger/slog.go b/pkg/utils/logger/slog.go index 46cf904..0bc4930 100644 --- a/pkg/utils/logger/slog.go +++ b/pkg/utils/logger/slog.go @@ -5,7 +5,6 @@ import ( "context" "encoding" "fmt" - "github.com/alexisvisco/amigo/pkg/utils/events" "io" "log/slog" "path/filepath" @@ -15,6 +14,8 @@ import ( "sync" "time" "unicode" + + "github.com/alexisvisco/amigo/pkg/utils/events" ) const errKey = "err" @@ -24,7 +25,10 @@ var ( defaultTimeFormat = time.StampMilli ) -var ShowSQLEvents = false +var ( + ShowSQLEvents = false + Logger = slog.Default() +) // Options for a slog.Handler that writes tinted logs. A zero Options consists // entirely of default values. @@ -392,7 +396,7 @@ func event(event any) *slog.Logger { name = en.EventName() } - return slog.With(slog.Any("event", event), slog.String("event_name", name)) + return Logger.With(slog.Any("event", event), slog.String("event_name", name)) } func Info(evt any) {