diff --git a/cmd/context.go b/cmd/context.go deleted file mode 100644 index af1048f..0000000 --- a/cmd/context.go +++ /dev/null @@ -1,35 +0,0 @@ -package cmd - -import ( - "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/spf13/cobra" - "github.com/spf13/viper" -) - -const contextFileName = "config.yml" - -// contextCmd represents the context command -var contextCmd = &cobra.Command{ - Use: "context", - Short: "save flags into a context", - Long: `A context is a file inside the amigo folder that contains the flags that you use in the command line. - -Example: - amigo context --dsn "postgres://user:password@host:port/dbname?sslmode=disable" - -This command will create a file $amigo_folder/context.yaml with the content: - dsn: "postgres://user:password@host:port/dbname?sslmode=disable" -`, - Run: wrapCobraFunc(func(cmd *cobra.Command, _ amigo.Amigo, args []string) error { - err := viper.WriteConfig() - if err != nil { - return err - } - - return nil - }), -} - -func init() { - rootCmd.AddCommand(contextCmd) -} diff --git a/cmd/init.go b/cmd/init.go index 29f9e1d..4dd83f6 100644 --- a/cmd/init.go +++ b/cmd/init.go @@ -2,87 +2,105 @@ package cmd import ( "fmt" + "path" + "time" + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/templates" "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils" "github.com/alexisvisco/amigo/pkg/utils/events" "github.com/alexisvisco/amigo/pkg/utils/logger" - "github.com/spf13/cobra" - "path" - "time" + "gopkg.in/yaml.v3" ) -// initCmd represents the init command -var initCmd = &cobra.Command{ - Use: "init", - Short: "Initialize migrations folder and add the first migration file", - Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { - if err := cmdCtx.ValidateDSN(); err != nil { - return err - } - - // create the main file - logger.Info(events.FolderAddedEvent{FolderName: cmdCtx.MigrationFolder}) - - file, err := utils.CreateOrOpenFile(path.Join(cmdCtx.AmigoFolderPath, "main.go")) - if err != nil { - return fmt.Errorf("unable to open main.go file: %w", err) - } - - err = am.GenerateMainFile(file) - if err != nil { - return err - } - - logger.Info(events.FileAddedEvent{FileName: path.Join(cmdCtx.AmigoFolderPath, "main.go")}) - - // create the base schema version table - now := time.Now() - migrationFileName := fmt.Sprintf("%s_create_table_schema_version.go", now.UTC().Format(utils.FormatTime)) - file, err = utils.CreateOrOpenFile(path.Join(cmdCtx.MigrationFolder, migrationFileName)) - if err != nil { - return fmt.Errorf("unable to open migrations.go file: %w", err) - } - - inUp, err := templates.GetInitCreateTableTemplate(templates.CreateTableData{Name: cmdCtx.SchemaVersionTable}, - am.Driver == types.DriverUnknown) - if err != nil { - return err - } - - err = am.GenerateMigrationFile(&amigo.GenerateMigrationFileParams{ - Name: "create_table_schema_version", - Up: inUp, - Down: "// nothing to do to keep the schema version table", - Type: types.MigrationFileTypeClassic, - Now: now, - Writer: file, - UseSchemaImport: am.Driver != types.DriverUnknown, - UseFmtImport: am.Driver == types.DriverUnknown, - }) - if err != nil { - return err - } - logger.Info(events.FileAddedEvent{FileName: path.Join(cmdCtx.MigrationFolder, migrationFileName)}) - - // create the migrations file where all the migrations will be stored - file, err = utils.CreateOrOpenFile(path.Join(cmdCtx.MigrationFolder, migrationsFile)) - if err != nil { - return err - } - - err = am.GenerateMigrationsFiles(file) - if err != nil { - return err - } - - logger.Info(events.FileAddedEvent{FileName: path.Join(cmdCtx.MigrationFolder, migrationsFile)}) - - return nil - }), -} +func executeInit( + mainFilePath, + amigoFolder, + table, + migrationsFolder string, +) error { + // create the main file + logger.Info(events.FolderAddedEvent{FolderName: amigoFolder}) + + file, err := utils.CreateOrOpenFile(mainFilePath) + if err != nil { + return fmt.Errorf("unable to open main.go file: %w", err) + } + + cfg := amigoconfig.NewConfig(). + WithAmigoFolder(amigoFolder). + WithMigrationFolder(migrationsFolder). + WithSchemaVersionTable(table) + + am := amigo.NewAmigo(cfg) + + err = am.GenerateMainFile(file) + if err != nil { + return err + } + + logger.Info(events.FileAddedEvent{FileName: mainFilePath}) + + // create the base schema version table + now := time.Now() + migrationFileName := fmt.Sprintf("%s_create_table_schema_version.go", now.UTC().Format(utils.FormatTime)) + file, err = utils.CreateOrOpenFile(path.Join(cfg.MigrationFolder, migrationFileName)) + if err != nil { + return fmt.Errorf("unable to open migrationsFolder.go file: %w", err) + } + + inUp, err := templates.GetInitCreateTableTemplate(templates.CreateTableData{Name: table}, + am.Driver == types.DriverUnknown) + if err != nil { + return err + } + + err = am.GenerateMigrationFile(&amigo.GenerateMigrationFileParams{ + Name: "create_table_schema_version", + Up: inUp, + Down: "// nothing to do to keep the schema version table", + Type: types.MigrationFileTypeClassic, + Now: now, + Writer: file, + UseSchemaImport: am.Driver != types.DriverUnknown, + UseFmtImport: am.Driver == types.DriverUnknown, + }) + if err != nil { + return err + } + logger.Info(events.FileAddedEvent{FileName: path.Join(migrationsFolder, migrationFileName)}) + + // create the migrationsFolder file where all the migrationsFolder will be stored + file, err = utils.CreateOrOpenFile(path.Join(amigoFolder, "migrations.go")) + if err != nil { + return err + } + + err = am.GenerateMigrationsFiles(file) + if err != nil { + return err + } + + logger.Info(events.FileAddedEvent{FileName: path.Join(amigoFolder, migrationFileName)}) + + // write the context file + out, err := yaml.Marshal(amigoconfig.DefaultYamlConfig) + if err != nil { + return err + } + + openFile, err := utils.CreateOrOpenFile(path.Join(amigoFolder, "contexts.yaml")) + if err != nil { + return fmt.Errorf("unable to open contexts.yaml file: %w", err) + } + defer openFile.Close() + + _, err = openFile.WriteString(string(out)) + if err != nil { + return fmt.Errorf("unable to write contexts.yaml file: %w", err) + } -func init() { - rootCmd.AddCommand(initCmd) + return nil } diff --git a/cmd/migration.go b/cmd/migration.go deleted file mode 100644 index f8a4383..0000000 --- a/cmd/migration.go +++ /dev/null @@ -1,61 +0,0 @@ -package cmd - -import ( - "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" - "github.com/spf13/cobra" -) - -// migrateCmd represents the up command -var migrateCmd = &cobra.Command{ - Use: "migrate", - Short: "Apply the database", - Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { - if err := cmdCtx.ValidateDSN(); err != nil { - return err - } - - if err := am.ExecuteMain(amigo.MainArgMigrate); err != nil { - return err - } - - return nil - }), -} - -// rollbackCmd represents the down command -var rollbackCmd = &cobra.Command{ - Use: "rollback", - Short: "Rollback the database", - Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { - if err := cmdCtx.ValidateDSN(); err != nil { - return err - } - - return am.ExecuteMain(amigo.MainArgRollback) - }), -} - -func init() { - rootCmd.AddCommand(rollbackCmd) - rootCmd.AddCommand(migrateCmd) - - registerBase := func(cmd *cobra.Command, m *amigoctx.Migration) { - cmd.Flags().StringVar(&m.Version, "version", "", - "Apply a specific version format: 20240502083700 or 20240502083700_name.go") - cmd.Flags().BoolVar(&m.DryRun, "dry-run", false, "Run the migrations without applying them") - cmd.Flags().BoolVar(&m.ContinueOnError, "continue-on-error", false, - "Will not rollback the migration if an error occurs") - cmd.Flags().DurationVar(&m.Timeout, "timeout", amigoctx.DefaultTimeout, "The timeout for the migration") - cmd.Flags().BoolVarP(&m.DumpSchemaAfter, "dump-schema-after", "d", false, - "Dump schema after migrate/rollback (not compatible with --use-schema-dump)") - } - - registerBase(migrateCmd, cmdCtx.Migration) - migrateCmd.Flags().BoolVar(&cmdCtx.Migration.UseSchemaDump, "use-schema-dump", false, - "Use the schema file to apply the migration (for fresh install without any migration)") - - registerBase(rollbackCmd, cmdCtx.Migration) - rollbackCmd.Flags().IntVar(&cmdCtx.Migration.Steps, "steps", 1, "The number of steps to rollback") - -} diff --git a/cmd/root.go b/cmd/root.go index a587cd9..aee511e 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -3,199 +3,78 @@ package cmd import ( "fmt" "os" - "path/filepath" - - "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" - "github.com/alexisvisco/amigo/pkg/utils/events" - "github.com/alexisvisco/amigo/pkg/utils/logger" - "github.com/spf13/viper" + "path" + "slices" + "strings" + "github.com/alexisvisco/amigo/pkg/amigoconfig" + "github.com/alexisvisco/amigo/pkg/utils/cmdexec" "github.com/spf13/cobra" ) -var cmdCtx = amigoctx.NewContext() - -const ( - migrationsFile = "migrations.go" -) +var cmdConfig = &amigoconfig.Config{} // rootCmd represents the base command when called without any subcommands var rootCmd = &cobra.Command{ - Use: "amigo", - Short: "Tool to manage database migrations with go files", - Long: `Basic usage: -First you need to create a main folder with amigo init: - - will create a folder in db with a context file inside to not have to pass the dsn every time. - - Postgres: - $ amigo context --dsn "postgres://user:password@host:port/dbname?sslmode=disable" - - SQLite: - $ amigo context --dsn "sqlite:/path/to/db.sqlite" --schema-version-table mig_schema_versions - - Unknown Driver (Mysql in this case): - $ amigo context --dsn "user:password@tcp(host:port)/dbname" - - - $ amigo init - note: will create: - - folder named migrations with a file named migrations.go that contains the list of migrations - - a new migration to create the schema version table - - a main.go in the $amigo_folder - -Apply migrations: - $ amigo migrate - note: you can set --version to migrate a specific version - -Create a new migration: - $ amigo create "create_table_users" - note: you can set --dump if you already have a database and you want to create the first migration with what's - already in the database. --skip will add the version of the created migration inside the schema version table. - -Rollback a migration: - $ amigo rollback - note: you can set --step to rollback a specific number of migrations, and --version to rollback - to a specific version -`, - SilenceUsage: true, -} - -func Execute() { - _ = rootCmd.Execute() -} - -func init() { - rootCmd.PersistentFlags().StringVarP(&cmdCtx.AmigoFolderPath, "amigo-folder", "m", amigoctx.DefaultAmigoFolder, - "Folder path to use for creating amigo related files related to this repository") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.DSN, "dsn", "", - "The database connection string example: postgres://user:password@host:port/dbname?sslmode=disable") - - rootCmd.PersistentFlags().BoolVarP(&cmdCtx.JSON, "json", "j", false, "Output in json format") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.MigrationFolder, "folder", amigoctx.DefaultMigrationFolder, - "Folder where the migrations are stored") - - rootCmd.PersistentFlags().StringVarP(&cmdCtx.PackagePath, "package", "p", amigoctx.DefaultPackagePath, - "Package name of the migrations folder") - - rootCmd.PersistentFlags().StringVarP(&cmdCtx.SchemaVersionTable, "schema-version-table", "t", - amigoctx.DefaultSchemaVersionTable, "Table name to keep track of the migrations") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.ShellPath, "shell-path", amigoctx.DefaultShellPath, - "Shell to use (for: amigo create --dump, it uses pg dump command)") - - rootCmd.PersistentFlags().BoolVar(&cmdCtx.ShowSQL, "sql", false, "Print SQL queries") - - rootCmd.PersistentFlags().BoolVar(&cmdCtx.ShowSQLSyntaxHighlighting, "sql-syntax-highlighting", true, - "Print SQL queries with syntax highlighting") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.SchemaOutPath, "schema-out-path", amigoctx.DefaultSchemaOutPath, - "File path of the schema dump if any") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.PGDumpPath, "pg-dump-path", amigoctx.DefaultPGDumpPath, - "Path to the pg_dump command if --dump is set") - - rootCmd.PersistentFlags().StringVar(&cmdCtx.SchemaDBDumpSchema, "schema-db-dump-schema", - amigoctx.DefaultDBDumpSchema, "Schema to use when dumping schema") - - rootCmd.PersistentFlags().BoolVar(&cmdCtx.Debug, "debug", false, "Print debug information") - initConfig() -} - -func initConfig() { - // check if the file exists, if the file does not exist, create it - if _, err := os.Stat(filepath.Join(cmdCtx.AmigoFolderPath, contextFileName)); os.IsNotExist(err) { - err := os.MkdirAll(cmdCtx.AmigoFolderPath, 0755) - if err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("error: can't create folder: %s", err)}) - os.Exit(1) - } - if err := viper.WriteConfigAs(filepath.Join(cmdCtx.AmigoFolderPath, contextFileName)); err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("error: can't write config: %s", err)}) - os.Exit(1) + Use: "amigo", + Short: "Tool to manage database migrations with go files", + SilenceUsage: true, + DisableFlagParsing: true, + RunE: func(cmd *cobra.Command, args []string) error { + shellPath := amigoconfig.DefaultShellPath + defaultAmigoFolder := amigoconfig.DefaultAmigoFolder + + if env, ok := os.LookupEnv("AMIGO_FOLDER"); ok { + defaultAmigoFolder = env } - } - _ = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn")) - _ = viper.BindPFlag("json", rootCmd.PersistentFlags().Lookup("json")) - _ = viper.BindPFlag("folder", rootCmd.PersistentFlags().Lookup("folder")) - _ = viper.BindPFlag("package", rootCmd.PersistentFlags().Lookup("package")) - _ = viper.BindPFlag("schema-version-table", rootCmd.PersistentFlags().Lookup("schema-version-table")) - _ = viper.BindPFlag("shell-path", rootCmd.PersistentFlags().Lookup("shell-path")) - _ = viper.BindPFlag("pg-dump-path", createCmd.Flags().Lookup("pg-dump-path")) - _ = viper.BindPFlag("sql", rootCmd.PersistentFlags().Lookup("sql")) - _ = viper.BindPFlag("sql-syntax-highlighting", rootCmd.PersistentFlags().Lookup("sql-syntax-highlighting")) - _ = viper.BindPFlag("schema-out-path", rootCmd.Flags().Lookup("schema-out-path")) - _ = viper.BindPFlag("schema-db-dump-schema", rootCmd.Flags().Lookup("schema-db-dump-schema")) - _ = viper.BindPFlag("debug", rootCmd.PersistentFlags().Lookup("debug")) - - viper.SetConfigFile(filepath.Join(cmdCtx.AmigoFolderPath, contextFileName)) - - if err := viper.ReadInConfig(); err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("error: can't read config: %s", err)}) - os.Exit(1) - } - - if viper.IsSet("dsn") { - cmdCtx.DSN = viper.GetString("dsn") - } + schemaVersionTable := amigoconfig.DefaultSchemaVersionTable + mainFilePath := path.Join(defaultAmigoFolder, "main.go") + mainBinaryPath := path.Join(defaultAmigoFolder, "main") + migrationFolder := amigoconfig.DefaultMigrationFolder - if viper.IsSet("json") { - cmdCtx.JSON = viper.GetBool("json") - } - - if viper.IsSet("folder") { - cmdCtx.MigrationFolder = viper.GetString("folder") - } - - if viper.IsSet("package") { - cmdCtx.PackagePath = viper.GetString("package") - } + if slices.Contains(args, "init") { + return executeInit(mainFilePath, defaultAmigoFolder, schemaVersionTable, migrationFolder) + } - if viper.IsSet("schema-version-table") { - cmdCtx.SchemaVersionTable = viper.GetString("schema-version-table") - } + return executeMain(shellPath, mainFilePath, mainBinaryPath, args) + }, +} - if viper.IsSet("shell-path") { - cmdCtx.ShellPath = viper.GetString("shell-path") - } +func Execute() { + _ = rootCmd.Execute() +} - if viper.IsSet("pg-dump-path") { - cmdCtx.PGDumpPath = viper.GetString("pg-dump-path") +func executeMain(shellPath, mainFilePath, mainBinaryPath string, restArgs []string) error { + _, err := os.Stat(mainFilePath) + if os.IsNotExist(err) { + return fmt.Errorf("%s file does not exist, please run 'amigo init' to create it", mainFilePath) } - if viper.IsSet("sql") { - cmdCtx.ShowSQL = viper.GetBool("sql") + // build binary + args := []string{ + "go", "build", + "-o", mainBinaryPath, + mainFilePath, } - if viper.IsSet("sql-syntax-highlighting") { - cmdCtx.ShowSQLSyntaxHighlighting = viper.GetBool("sql-syntax-highlighting") + err = cmdexec.ExecToWriter(shellPath, []string{"-c", strings.Join(args, " ")}, nil, os.Stdout, os.Stderr) + if err != nil { + return err } - if viper.IsSet("debug") { - cmdCtx.Debug = viper.GetBool("debug") + args = []string{ + "./" + mainBinaryPath, } - if viper.IsSet("schema-out-path") { - cmdCtx.SchemaOutPath = viper.GetString("schema-out-path") + if len(restArgs) > 0 { + args = append(args, restArgs...) } - if viper.IsSet("schema-db-dump-schema") { - cmdCtx.SchemaDBDumpSchema = viper.GetString("schema-db-dump-schema") + err = cmdexec.ExecToWriter(shellPath, []string{"-c", strings.Join(args, " ")}, nil, os.Stdout, os.Stderr) + if err != nil { + return fmt.Errorf("%s throw an error: %w", mainFilePath, err) } -} - -func wrapCobraFunc(f func(cmd *cobra.Command, am amigo.Amigo, args []string) error) func(cmd *cobra.Command, args []string) { - return func(cmd *cobra.Command, args []string) { - am := amigo.NewAmigo(cmdCtx) - am.SetupSlog(os.Stdout, nil) - if err := f(cmd, am, args); err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) - } - } + return nil } diff --git a/cmd/status.go b/cmd/status.go deleted file mode 100644 index 6bf0fe0..0000000 --- a/cmd/status.go +++ /dev/null @@ -1,22 +0,0 @@ -package cmd - -import ( - "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/spf13/cobra" -) - -var statusCmd = &cobra.Command{ - Use: "status", - Short: "Status explain the current state of the database.", - Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { - if err := cmdCtx.ValidateDSN(); err != nil { - return err - } - - return am.ExecuteMain(amigo.MainArgStatus) - }), -} - -func init() { - rootCmd.AddCommand(statusCmd) -} diff --git a/docs/docs/04-api/100-migrating-in-go.md b/docs/docs/04-api/100-migrating-in-go.md index 7c168e1..e659ce6 100644 --- a/docs/docs/04-api/100-migrating-in-go.md +++ b/docs/docs/04-api/100-migrating-in-go.md @@ -9,7 +9,7 @@ import ( "database/sql" "example/pg/db/migrations" "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/types" _ "github.com/jackc/pgx/v5/stdlib" "os" diff --git a/e2e_test.go b/e2e_test.go index ade05bd..683bef2 100644 --- a/e2e_test.go +++ b/e2e_test.go @@ -8,7 +8,7 @@ import ( "testing" "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/schema" "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils" @@ -75,7 +75,7 @@ func createSchema(t *testing.T, connection *sql.DB, s string) { // then rollback all migrations, check the snapshot with the first one // then up all migrations, check the snapshot with the last one func ensureMigrationsAreReversible(t *testing.T, db schema.DatabaseCredentials, migrations []schema.Migration, sql *sql.DB, dsn, schema string) { - actx := amigoctx.NewContext() + actx := amigoconfig.NewConfig() actx.ShowSQL = true actx.DSN = dsn actx.SchemaVersionTable = schema + ".mig_schema_versions" diff --git a/example/pg/main.go b/example/pg/main.go index fba4fbd..c7ae20c 100644 --- a/example/pg/main.go +++ b/example/pg/main.go @@ -4,7 +4,7 @@ import ( "database/sql" "example/pg/migrations" "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/types" _ "github.com/jackc/pgx/v5/stdlib" "os" @@ -18,7 +18,7 @@ func main() { panic(err) } - err = amigo.NewAmigo(amigoctx.NewContext().WithDSN(dsn)).RunMigrations(amigo.RunMigrationParams{ + err = amigo.NewAmigo(amigoconfig.NewContext().WithDSN(dsn)).RunMigrations(amigo.RunMigrationParams{ DB: db, Direction: types.MigrationDirectionDown, Migrations: migrations.Migrations, diff --git a/pkg/amigo/amigo.go b/pkg/amigo/amigo.go index c1bc1c9..9817e48 100644 --- a/pkg/amigo/amigo.go +++ b/pkg/amigo/amigo.go @@ -1,19 +1,82 @@ package amigo import ( - "github.com/alexisvisco/amigo/pkg/amigoctx" + "context" + "database/sql" + + "github.com/alexisvisco/amigo/pkg/amigoconfig" + "github.com/alexisvisco/amigo/pkg/schema" + "github.com/alexisvisco/amigo/pkg/schema/base" + "github.com/alexisvisco/amigo/pkg/schema/pg" + "github.com/alexisvisco/amigo/pkg/schema/sqlite" "github.com/alexisvisco/amigo/pkg/types" + "github.com/alexisvisco/amigo/pkg/utils/dblog" + sqldblogger "github.com/simukti/sqldb-logger" ) type Amigo struct { - ctx *amigoctx.Context - Driver types.Driver + Config *amigoconfig.Config + Driver types.Driver + CustomSchemaFactory schema.Factory[schema.Schema] +} + +type OptionFn func(a *Amigo) + +func WithCustomSchemaFactory(factory schema.Factory[schema.Schema]) OptionFn { + return func(a *Amigo) { + a.CustomSchemaFactory = factory + } } // NewAmigo create a new amigo instance -func NewAmigo(ctx *amigoctx.Context) Amigo { - return Amigo{ - ctx: ctx, +func NewAmigo(ctx *amigoconfig.Config, opts ...OptionFn) Amigo { + a := Amigo{ + Config: ctx, Driver: types.GetDriver(ctx.DSN), } + + for _, opt := range opts { + opt(&a) + } + + return a +} + +type MigrationApplier interface { + Apply(direction types.MigrationDirection, version *string, steps *int, migrations []schema.Migration) bool + GetSchema() schema.Schema +} + +func (a Amigo) GetMigrationApplier( + ctx context.Context, + conn *sql.DB, +) (MigrationApplier, error) { + recorder := dblog.NewHandler(a.Config.ShowSQLSyntaxHighlighting) + recorder.ToggleLogger(true) + + if a.Config.ValidateDSN() == nil { + conn = sqldblogger.OpenDriver(a.Config.GetRealDSN(), conn.Driver(), recorder) + } + + if a.CustomSchemaFactory != nil { + return schema.NewMigrator(ctx, conn, a.CustomSchemaFactory, a.Config), nil + } + + switch a.Driver { + case types.DriverPostgres: + return schema.NewMigrator(ctx, conn, pg.NewPostgres, a.Config), nil + case types.DriverSQLite: + return schema.NewMigrator(ctx, conn, sqlite.NewSQLite, a.Config), nil + } + + return schema.NewMigrator(ctx, conn, base.NewBase, a.Config), nil +} + +func (a Amigo) GetSchema(ctx context.Context, conn *sql.DB) (schema.Schema, error) { + applier, err := a.GetMigrationApplier(ctx, conn) + if err != nil { + return nil, err + } + + return applier.GetSchema(), nil } diff --git a/pkg/amigo/dump_schema.go b/pkg/amigo/dump_schema.go index bca35e8..f73b6ae 100644 --- a/pkg/amigo/dump_schema.go +++ b/pkg/amigo/dump_schema.go @@ -13,23 +13,32 @@ import ( ) func (a Amigo) DumpSchema(writer io.Writer, ignoreSchemaVersionTable bool) error { - db, err := schema.ExtractCredentials(a.ctx.GetRealDSN()) + // Config variables declaration + var ( + realDSN = a.Config.GetRealDSN() + pgDumpPath = a.Config.PGDumpPath + schemaVersionTable = a.Config.SchemaVersionTable + schemaDBDumpSchema = a.Config.SchemaToDump + shellPath = a.Config.ShellPath + ) + + db, err := schema.ExtractCredentials(realDSN) if err != nil { return err } - ignoreTableName := a.ctx.SchemaVersionTable + ignoreTableName := schemaVersionTable if strings.Contains(ignoreTableName, ".") { ignoreTableName = strings.Split(ignoreTableName, ".")[1] } args := []string{ - a.ctx.PGDumpPath, + pgDumpPath, "-d", db.DB, "-h", db.Host, "-U", db.User, "-p", db.Port, - "-n", a.ctx.SchemaDBDumpSchema, + "-n", schemaDBDumpSchema, "-s", "-x", "-O", @@ -46,7 +55,7 @@ func (a Amigo) DumpSchema(writer io.Writer, ignoreSchemaVersionTable bool) error env := map[string]string{"PGPASSWORD": db.Pass} - stdout, stderr, err := cmdexec.Exec(a.ctx.ShellPath, []string{"-c", strings.Join(args, " ")}, env) + stdout, stderr, err := cmdexec.Exec(shellPath, []string{"-c", strings.Join(args, " ")}, env) if err != nil { return fmt.Errorf("unable to dump database: %w\n%s", err, stderr) } @@ -64,8 +73,8 @@ func (a Amigo) DumpSchema(writer io.Writer, ignoreSchemaVersionTable bool) error schemaPattern := fmt.Sprintf( `(?s)(.*?)(?:--\s*\n--\s*Name:\s*%s;\s*Type:\s*SCHEMA.*?\n--\s*\n\s*CREATE\s+SCHEMA\s+%s;\s*\n)(.*)`, - regexp.QuoteMeta(a.ctx.SchemaDBDumpSchema), - regexp.QuoteMeta(a.ctx.SchemaDBDumpSchema), + regexp.QuoteMeta(schemaDBDumpSchema), + regexp.QuoteMeta(schemaDBDumpSchema), ) dumpParts := regexp.MustCompile(schemaPattern).FindStringSubmatch(stdout) @@ -73,7 +82,7 @@ func (a Amigo) DumpSchema(writer io.Writer, ignoreSchemaVersionTable bool) error return fmt.Errorf("failed to parse schema dump: unexpected format") } - setSchemaPath := fmt.Sprintf("SET search_path TO %s;\n", a.ctx.SchemaDBDumpSchema) + setSchemaPath := fmt.Sprintf("SET search_path TO %s;\n", schemaDBDumpSchema) // Combine all parts with the proper ordering result := dateGenerated + diff --git a/pkg/amigo/execute_main.go b/pkg/amigo/execute_main.go deleted file mode 100644 index 3ab8294..0000000 --- a/pkg/amigo/execute_main.go +++ /dev/null @@ -1,71 +0,0 @@ -package amigo - -import ( - "encoding/base64" - "encoding/json" - "fmt" - "os" - "path" - "strings" - - "github.com/alexisvisco/amigo/pkg/utils/cmdexec" -) - -type MainArg string - -const ( - MainArgMigrate MainArg = "migrate" - MainArgRollback MainArg = "rollback" - MainArgSkipMigration MainArg = "skip-migration" - MainArgStatus MainArg = "status" -) - -func (m MainArg) Validate() error { - switch m { - case MainArgMigrate, MainArgRollback, MainArgSkipMigration, MainArgStatus: - return nil - } - - return fmt.Errorf("invalid main arg: %s", m) -} - -func (a Amigo) ExecuteMain(arg MainArg) error { - mainFilePath := path.Join(a.ctx.AmigoFolderPath, "main.go") - mainBinaryPath := path.Join(a.ctx.AmigoFolderPath, "main") - _, err := os.Stat(mainFilePath) - if os.IsNotExist(err) { - return fmt.Errorf("%s file does not exist, please run 'amigo init' to create it", mainFilePath) - } - - // build binary - args := []string{ - "go", "build", - "-o", mainBinaryPath, - mainFilePath, - } - - err = cmdexec.ExecToWriter(a.ctx.ShellPath, []string{"-c", strings.Join(args, " ")}, nil, os.Stdout, os.Stderr) - if err != nil { - return err - } - - amigoCtxJson, err := json.Marshal(a.ctx) - if err != nil { - return fmt.Errorf("unable to marshal amigo context: %w", err) - } - - bas64Json := base64.StdEncoding.EncodeToString(amigoCtxJson) - args = []string{ - "./" + mainBinaryPath, - "-json", bas64Json, - } - - args = append(args, string(arg)) - - err = cmdexec.ExecToWriter(a.ctx.ShellPath, []string{"-c", strings.Join(args, " ")}, nil, os.Stdout, os.Stderr) - if err != nil { - return fmt.Errorf("%s throw an error: %w", mainFilePath, err) - } - - return nil -} diff --git a/pkg/amigo/gen_main_file.go b/pkg/amigo/gen_main_file.go index 88bdbfc..b6394ef 100644 --- a/pkg/amigo/gen_main_file.go +++ b/pkg/amigo/gen_main_file.go @@ -9,21 +9,23 @@ import ( "github.com/alexisvisco/amigo/pkg/utils" ) -// GenerateMainFile generate the main.go file in the amigo folder func (a Amigo) GenerateMainFile(writer io.Writer) error { + var ( + migrationFolder = a.Config.MigrationFolder + ) + name, err := utils.GetModuleName() if err != nil { return fmt.Errorf("unable to get module name: %w", err) } - packagePath := path.Join(name, a.ctx.MigrationFolder) + packagePath := path.Join(name, migrationFolder) template, err := templates.GetMainTemplate(templates.MainData{ PackagePath: packagePath, DriverPath: a.Driver.PackagePath(), DriverName: a.Driver.String(), }) - if err != nil { return fmt.Errorf("unable to get main template: %w", err) } diff --git a/pkg/amigo/gen_migration_file.go b/pkg/amigo/gen_migration_file.go index 02830e3..f3735d0 100644 --- a/pkg/amigo/gen_migration_file.go +++ b/pkg/amigo/gen_migration_file.go @@ -23,8 +23,10 @@ type GenerateMigrationFileParams struct { Writer io.Writer } -// GenerateMigrationFile generate a migration file in the migrations folder func (a Amigo) GenerateMigrationFile(params *GenerateMigrationFileParams) error { + var ( + migrationPackageName = a.Config.MigrationPackageName + ) structName := utils.MigrationStructName(params.Now, params.Name) @@ -37,7 +39,7 @@ func (a Amigo) GenerateMigrationFile(params *GenerateMigrationFileParams) error fileContent, err := templates.GetMigrationTemplate(templates.MigrationData{ IsSQL: params.Type == types.MigrationFileTypeSQL, - Package: a.ctx.PackagePath, + Package: migrationPackageName, StructName: structName, Name: flect.Underscore(params.Name), Type: params.Type, diff --git a/pkg/amigo/gen_migrations_file.go b/pkg/amigo/gen_migrations_file.go index 664839a..8fe8434 100644 --- a/pkg/amigo/gen_migrations_file.go +++ b/pkg/amigo/gen_migrations_file.go @@ -29,7 +29,7 @@ func (a Amigo) GenerateMigrationsFiles(writer io.Writer) error { a.Driver.StructName(), migrationFiles[k].fulName, k.Format(time.RFC3339), - a.ctx.Create.SQLSeparator, + a.Config.Create.SQLSeparator, ) migrations = append(migrations, line) @@ -45,7 +45,7 @@ func (a Amigo) GenerateMigrationsFiles(writer io.Writer) error { } content, err := templates.GetMigrationsTemplate(templates.MigrationsData{ - Package: a.ctx.PackagePath, + Package: a.Config.MigrationPackageName, Migrations: migrations, ImportSchemaPackage: mustImportSchemaPackage, }) @@ -72,7 +72,7 @@ func (a Amigo) getMigrationFiles(ascending bool) (map[time.Time]migrationFile, [ migrationFiles := make(map[time.Time]migrationFile) // get the list of structs by the file name - err := filepath.Walk(a.ctx.MigrationFolder, func(path string, info os.FileInfo, err error) error { + err := filepath.Walk(a.Config.MigrationFolder, func(path string, info os.FileInfo, err error) error { if err != nil { return err } diff --git a/pkg/amigo/get_status.go b/pkg/amigo/get_status.go index 00f586b..c37670f 100644 --- a/pkg/amigo/get_status.go +++ b/pkg/amigo/get_status.go @@ -1,26 +1,17 @@ package amigo import ( + "context" "database/sql" "fmt" ) // GetStatus return the state of the database -func (a Amigo) GetStatus(db *sql.DB) ([]string, error) { - rows, err := db.Query("SELECT version FROM " + a.ctx.SchemaVersionTable + " ORDER BY version desc") +func (a Amigo) GetStatus(ctx context.Context, db *sql.DB) ([]string, error) { + schema, err := a.GetSchema(ctx, db) if err != nil { - return nil, fmt.Errorf("unable to get state: %w", err) + return nil, fmt.Errorf("unable to get schema: %w", err) } - var state []string - for rows.Next() { - var id string - err := rows.Scan(&id) - if err != nil { - return nil, fmt.Errorf("unable to scan state: %w", err) - } - state = append(state, id) - } - - return state, nil + return schema.FindAppliedVersions(), nil } diff --git a/pkg/amigo/run_migration.go b/pkg/amigo/run_migration.go index cff6fdb..0a76386 100644 --- a/pkg/amigo/run_migration.go +++ b/pkg/amigo/run_migration.go @@ -10,17 +10,12 @@ import ( "path" "time" - "github.com/alexisvisco/amigo/pkg/amigoctx" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/schema" - "github.com/alexisvisco/amigo/pkg/schema/base" - "github.com/alexisvisco/amigo/pkg/schema/pg" - "github.com/alexisvisco/amigo/pkg/schema/sqlite" "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils" - "github.com/alexisvisco/amigo/pkg/utils/dblog" "github.com/alexisvisco/amigo/pkg/utils/events" "github.com/alexisvisco/amigo/pkg/utils/logger" - sqldblogger "github.com/simukti/sqldb-logger" ) var ( @@ -28,10 +23,6 @@ var ( ErrMigrationFailed = errors.New("migration failed") ) -type migrationApplier interface { - Apply(direction types.MigrationDirection, version *string, steps *int, migrations []schema.Migration) bool -} - type RunMigrationParams struct { DB *sql.DB Direction types.MigrationDirection @@ -53,20 +44,20 @@ func (a Amigo) RunMigrations(params RunMigrationParams) error { originCtx = params.Context } - ctx, cancel := context.WithDeadline(originCtx, time.Now().Add(a.ctx.Migration.Timeout)) + ctx, cancel := context.WithDeadline(originCtx, time.Now().Add(a.Config.Migration.Timeout)) defer cancel() a.SetupSlog(params.LogOutput, params.Logger) - migrator, err := a.getMigrationApplier(ctx, params.DB) + migrator, err := a.GetMigrationApplier(ctx, params.DB) if err != nil { return err } ok := migrator.Apply( params.Direction, - utils.NilOrValue(a.ctx.Migration.Version), - utils.NilOrValue(a.ctx.Migration.Steps), + utils.NilOrValue(a.Config.Migration.Version), + utils.NilOrValue(a.Config.Migration.Steps), params.Migrations, ) @@ -74,8 +65,8 @@ func (a Amigo) RunMigrations(params RunMigrationParams) error { return ErrMigrationFailed } - if a.ctx.Migration.DumpSchemaAfter { - file, err := utils.CreateOrOpenFile(a.ctx.SchemaOutPath) + if a.Config.Migration.DumpSchemaAfter { + file, err := utils.CreateOrOpenFile(a.Config.SchemaOutPath) if err != nil { return fmt.Errorf("unable to open/create file: %w", err) } @@ -87,23 +78,23 @@ func (a Amigo) RunMigrations(params RunMigrationParams) error { return fmt.Errorf("unable to dump schema after migrating: %w", err) } - logger.Info(events.FileModifiedEvent{FileName: path.Join(a.ctx.SchemaOutPath)}) + logger.Info(events.FileModifiedEvent{FileName: path.Join(a.Config.SchemaOutPath)}) } return nil } func (a Amigo) validateRunMigration(conn *sql.DB, direction *types.MigrationDirection) error { - if a.ctx.SchemaVersionTable == "" { - a.ctx.SchemaVersionTable = amigoctx.DefaultSchemaVersionTable + if a.Config.SchemaVersionTable == "" { + a.Config.SchemaVersionTable = amigoconfig.DefaultSchemaVersionTable } if direction == nil || *direction == "" { *direction = types.MigrationDirectionUp } - if a.ctx.Migration.Timeout == 0 { - a.ctx.Migration.Timeout = amigoctx.DefaultTimeout + if a.Config.Migration.Timeout == 0 { + a.Config.Migration.Timeout = amigoconfig.DefaultTimeout } if conn == nil { @@ -112,33 +103,3 @@ func (a Amigo) validateRunMigration(conn *sql.DB, direction *types.MigrationDire return nil } - -func (a Amigo) getMigrationApplier( - ctx context.Context, - conn *sql.DB, -) (migrationApplier, error) { - recorder := dblog.NewHandler(a.ctx.ShowSQLSyntaxHighlighting) - recorder.ToggleLogger(true) - - if a.ctx.ValidateDSN() == nil { - conn = sqldblogger.OpenDriver(a.ctx.GetRealDSN(), conn.Driver(), recorder) - } - - opts := &schema.MigratorOption{ - DryRun: a.ctx.Migration.DryRun, - ContinueOnError: a.ctx.Migration.ContinueOnError, - SchemaVersionTable: schema.TableName(a.ctx.SchemaVersionTable), - DBLogger: recorder, - DumpSchemaFilePath: utils.NilOrValue(a.ctx.SchemaOutPath), - UseSchemaDump: a.ctx.Migration.UseSchemaDump, - } - - switch a.Driver { - case types.DriverPostgres: - return schema.NewMigrator(ctx, conn, pg.NewPostgres, opts), nil - case types.DriverSQLite: - return schema.NewMigrator(ctx, conn, sqlite.NewSQLite, opts), nil - } - - return schema.NewMigrator(ctx, conn, base.NewBase, opts), nil -} diff --git a/pkg/amigo/setup_slog.go b/pkg/amigo/setup_slog.go index 189aaaf..ac314d8 100644 --- a/pkg/amigo/setup_slog.go +++ b/pkg/amigo/setup_slog.go @@ -8,7 +8,7 @@ import ( ) func (a Amigo) SetupSlog(writer io.Writer, mayLogger *slog.Logger) { - logger.ShowSQLEvents = a.ctx.ShowSQL + logger.ShowSQLEvents = a.Config.ShowSQL if writer == nil && mayLogger == nil { logger.Logger = slog.New(slog.NewJSONHandler(writer, &slog.HandlerOptions{Level: slog.LevelError})) return @@ -20,11 +20,11 @@ func (a Amigo) SetupSlog(writer io.Writer, mayLogger *slog.Logger) { } level := slog.LevelInfo - if a.ctx.Debug { + if a.Config.Debug { level = slog.LevelDebug } - if a.ctx.JSON { + if a.Config.JSON { logger.Logger = slog.New(slog.NewJSONHandler(writer, &slog.HandlerOptions{Level: level})) } else { logger.Logger = slog.New(logger.NewHandler(writer, &logger.Options{Level: level})) diff --git a/pkg/amigo/skip_migration_file.go b/pkg/amigo/skip_migration_file.go index f9837c1..5c7213e 100644 --- a/pkg/amigo/skip_migration_file.go +++ b/pkg/amigo/skip_migration_file.go @@ -1,15 +1,18 @@ package amigo import ( + "context" "database/sql" "fmt" ) -func (a Amigo) SkipMigrationFile(db *sql.DB) error { - _, err := db.Exec("INSERT INTO "+a.ctx.SchemaVersionTable+" (id) VALUES ($1)", a.ctx.Create.Version) +func (a Amigo) SkipMigrationFile(ctx context.Context, db *sql.DB) error { + schema, err := a.GetSchema(ctx, db) if err != nil { - return fmt.Errorf("unable to skip migration file: %w", err) + return fmt.Errorf("unable to get schema: %w", err) } + schema.AddVersion(a.Config.Create.Version) + return nil } diff --git a/pkg/amigoconfig/cli_context.go b/pkg/amigoconfig/cli_context.go new file mode 100644 index 0000000..e11e884 --- /dev/null +++ b/pkg/amigoconfig/cli_context.go @@ -0,0 +1,123 @@ +package amigoconfig + +import ( + "fmt" + "os" + "time" + + "gopkg.in/yaml.v3" +) + +type YamlConfig struct { + ShellPath string `yaml:"shell-path"` + Debug bool `yaml:"debug"` + ShowSQL bool `yaml:"show-sql"` + ShowSQLSyntaxHighlighting bool `yaml:"show-sql-syntax-highlighting"` + + CurrentContext string `yaml:"current-context"` + Contexts map[string]YamlConfigContext `yaml:"contexts"` +} + +type YamlConfigContext struct { + SchemaVersionTable string `yaml:"schema-version-table"` + MigrationFolder string `yaml:"migration-folder"` + MigrationPackageName string `yaml:"migration-package-name"` + SchemaToDump string `yaml:"schema-to-dump"` + SchemaOutPath string `yaml:"schema-out-path"` + Timeout time.Duration `yaml:"timeout"` + DSN string `yaml:"dsn"` + + PGDumpPath string `yaml:"pg-dump-path"` +} + +var DefaultYamlConfig = YamlConfig{ + ShellPath: DefaultShellPath, + Debug: false, + ShowSQL: true, + ShowSQLSyntaxHighlighting: true, + + CurrentContext: "default", + Contexts: map[string]YamlConfigContext{ + "default": { + SchemaVersionTable: DefaultSchemaVersionTable, + MigrationFolder: DefaultMigrationFolder, + MigrationPackageName: DefaultMigrationPackageName, + SchemaToDump: DefaultSchemaToDump, + SchemaOutPath: DefaultSchemaOutPath, + Timeout: DefaultTimeout, + DSN: "postgres://user:password@host:port/dbname?sslmode=disable", + PGDumpPath: DefaultPGDumpPath, + }, + }, +} + +func LoadYamlConfig(path string) (*YamlConfig, error) { + var config YamlConfig + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("unable to read contexts file: %w", err) + } + + err = yaml.Unmarshal(data, &config) + if err != nil { + return nil, err + } + + return &config, nil +} + +func (c *Config) OverrideWithYamlConfig(yaml *YamlConfig) { + if yaml == nil || len(yaml.Contexts) == 0 { + return + } + + context, ok := yaml.Contexts[yaml.CurrentContext] + if !ok { + return + } + + // Override root config values + if yaml.ShellPath != "" { + c.RootConfig.ShellPath = yaml.ShellPath + } + if yaml.Debug { + c.RootConfig.Debug = yaml.Debug + } + if yaml.ShowSQL { + c.RootConfig.ShowSQL = yaml.ShowSQL + } + if yaml.ShowSQLSyntaxHighlighting { + c.RootConfig.ShowSQLSyntaxHighlighting = yaml.ShowSQLSyntaxHighlighting + } + + // Override per-driver config values + if context.SchemaVersionTable != "" { + c.RootConfig.SchemaVersionTable = context.SchemaVersionTable + } + if context.MigrationFolder != "" { + c.RootConfig.MigrationFolder = context.MigrationFolder + } + if context.MigrationPackageName != "" { + c.RootConfig.MigrationPackageName = context.MigrationPackageName + } + if context.SchemaToDump != "" { + c.RootConfig.SchemaToDump = context.SchemaToDump + } + if context.SchemaOutPath != "" { + c.RootConfig.SchemaOutPath = context.SchemaOutPath + } + if context.PGDumpPath != "" { + c.RootConfig.PGDumpPath = context.PGDumpPath + } + if context.Timeout != 0 { + if c.Migration == nil { + c.Migration = &MigrationConfig{} + } + c.Migration.Timeout = context.Timeout + } + if context.DSN != "" { + c.RootConfig.DSN = context.DSN + } + + return +} diff --git a/pkg/amigoconfig/config.go b/pkg/amigoconfig/config.go new file mode 100644 index 0000000..2eb157c --- /dev/null +++ b/pkg/amigoconfig/config.go @@ -0,0 +1,447 @@ +package amigoconfig + +import ( + "errors" + "fmt" + "regexp" + "strings" + "time" + + "github.com/alexisvisco/amigo/pkg/types" +) + +var ( + ErrDSNEmpty = errors.New("dsn is empty") +) + +var ( + DefaultSchemaVersionTable = "public.mig_schema_versions" + DefaultAmigoFolder = "db" + DefaultMigrationFolder = "db/migrations" + DefaultMigrationPackageName = "migrations" + DefaultShellPath = "/bin/bash" + DefaultPGDumpPath = "pg_dump" + DefaultSchemaOutPath = "db/schema.sql" + DefaultTimeout = 2 * time.Minute + DefaultSchemaToDump = "public" + DefaultCreateMigrationSQLSeparator = "-- migrate:down" +) + +type Config struct { + *RootConfig + + Migration *MigrationConfig + Create *CreateConfig +} + +func NewConfig() *Config { + return &Config{ + RootConfig: &RootConfig{ + SchemaVersionTable: DefaultSchemaVersionTable, + AmigoFolderPath: DefaultAmigoFolder, + MigrationFolder: DefaultMigrationFolder, + MigrationPackageName: DefaultMigrationPackageName, + ShellPath: DefaultShellPath, + PGDumpPath: DefaultPGDumpPath, + SchemaOutPath: DefaultSchemaOutPath, + SchemaToDump: DefaultSchemaToDump, + }, + Migration: &MigrationConfig{ + Timeout: DefaultTimeout, + Steps: 1, + }, + Create: &CreateConfig{ + Type: string(types.MigrationFileTypeClassic), + Dump: false, + SQLSeparator: DefaultCreateMigrationSQLSeparator, + Skip: false, + }, + } +} + +type RootConfig struct { + AmigoFolderPath string + DSN string + JSON bool + ShowSQL bool + ShowSQLSyntaxHighlighting bool + MigrationFolder string + MigrationPackageName string + SchemaVersionTable string + ShellPath string + PGDumpPath string + SchemaOutPath string + SchemaToDump string + Debug bool +} + +func (r *Config) GetRealDSN() string { + switch types.GetDriver(r.RootConfig.DSN) { + case types.DriverSQLite: + return strings.TrimPrefix(r.RootConfig.DSN, "sqlite:") + } + + return r.RootConfig.DSN +} + +func (r *RootConfig) ValidateDSN() error { + if r.DSN == "" { + return ErrDSNEmpty + } + + return nil +} + +type MigrationConfig struct { + Version string + Steps int + DryRun bool + ContinueOnError bool + Timeout time.Duration + UseSchemaDump bool + DumpSchemaAfter bool +} + +func (m *MigrationConfig) ValidateVersion() error { + if m.Version == "" { + return nil + } + + re := regexp.MustCompile(`\d{14}(_\w+)?\.(go|sql)`) + if !re.MatchString(m.Version) { + return fmt.Errorf("version must be in the format: 20240502083700 or 20240502083700_name.go or 20240502083700_name.sql") + } + + return nil +} + +type CreateConfig struct { + Type string + Dump bool + + SQLSeparator string + + Skip bool + + // Version is post setted after the name have been generated from the arg and time + Version string +} + +func (c *CreateConfig) ValidateType() error { + allowedTypes := []string{string(types.MigrationFileTypeClassic), string(types.MigrationFileTypeChange), string(types.MigrationFileTypeSQL)} + + for _, t := range allowedTypes { + if c.Type == t { + return nil + } + } + + return fmt.Errorf("unsupported type, allowed types are: %s", strings.Join(allowedTypes, ", ")) +} + +func MergeConfig(toMerge Config) *Config { + defaultCtx := NewConfig() + + if toMerge.RootConfig != nil { + if toMerge.RootConfig.AmigoFolderPath != "" { + defaultCtx.RootConfig.AmigoFolderPath = toMerge.RootConfig.AmigoFolderPath + } + + if toMerge.RootConfig.DSN != "" { + defaultCtx.RootConfig.DSN = toMerge.RootConfig.DSN + } + + if toMerge.RootConfig.JSON { + defaultCtx.RootConfig.JSON = toMerge.RootConfig.JSON + } + + if toMerge.RootConfig.ShowSQL { + defaultCtx.RootConfig.ShowSQL = toMerge.RootConfig.ShowSQL + } + + if toMerge.RootConfig.MigrationFolder != "" { + defaultCtx.RootConfig.MigrationFolder = toMerge.RootConfig.MigrationFolder + } + + if toMerge.RootConfig.MigrationPackageName != "" { + defaultCtx.RootConfig.MigrationPackageName = toMerge.RootConfig.MigrationPackageName + } + + if toMerge.RootConfig.SchemaVersionTable != "" { + defaultCtx.RootConfig.SchemaVersionTable = toMerge.RootConfig.SchemaVersionTable + } + + if toMerge.RootConfig.ShellPath != "" { + defaultCtx.RootConfig.ShellPath = toMerge.RootConfig.ShellPath + } + + if toMerge.RootConfig.PGDumpPath != "" { + defaultCtx.RootConfig.PGDumpPath = toMerge.RootConfig.PGDumpPath + } + + if toMerge.RootConfig.Debug { + defaultCtx.RootConfig.Debug = toMerge.RootConfig.Debug + } + + if toMerge.RootConfig.SchemaToDump != "" { + defaultCtx.RootConfig.SchemaToDump = toMerge.RootConfig.SchemaToDump + } + + if toMerge.RootConfig.SchemaOutPath != "" { + defaultCtx.RootConfig.SchemaOutPath = toMerge.RootConfig.SchemaOutPath + } + } + + if toMerge.Migration != nil { + if toMerge.Migration.Version != "" { + defaultCtx.Migration.Version = toMerge.Migration.Version + } + + if toMerge.Migration.Steps != 0 { + defaultCtx.Migration.Steps = toMerge.Migration.Steps + } + + if toMerge.Migration.DryRun { + defaultCtx.Migration.DryRun = toMerge.Migration.DryRun + } + + if toMerge.Migration.ContinueOnError { + defaultCtx.Migration.ContinueOnError = toMerge.Migration.ContinueOnError + } + + if toMerge.Migration.Timeout != 0 { + defaultCtx.Migration.Timeout = toMerge.Migration.Timeout + } + + if toMerge.Migration.UseSchemaDump { + defaultCtx.Migration.UseSchemaDump = toMerge.Migration.UseSchemaDump + } + } + + if toMerge.Create != nil { + if toMerge.Create.Type != "" { + defaultCtx.Create.Type = toMerge.Create.Type + } + + if toMerge.Create.Dump { + defaultCtx.Create.Dump = toMerge.Create.Dump + } + + if toMerge.Create.Skip { + defaultCtx.Create.Skip = toMerge.Create.Skip + } + + if toMerge.Create.Version != "" { + defaultCtx.Create.Version = toMerge.Create.Version + } + } + + return defaultCtx +} + +// WithAmigoFolder sets the folder where amigo files are stored +// DefaultYamlConfig is "db" +func (a *Config) WithAmigoFolder(folder string) *Config { + a.RootConfig.AmigoFolderPath = folder + return a +} + +// WithMigrationFolder sets the folder where migration files are stored +// DefaultYamlConfig is "db/migrations" +func (a *Config) WithMigrationFolder(folder string) *Config { + a.RootConfig.MigrationFolder = folder + return a +} + +// WithMigrationPackageName sets the package name where migration files are stored +// DefaultYamlConfig is "migrations" +func (a *Config) WithMigrationPackageName(packageName string) *Config { + a.RootConfig.MigrationPackageName = packageName + return a +} + +// WithSchemaVersionTable sets the table name where the schema version is stored +// DefaultYamlConfig is "public.mig_schema_versions" +func (a *Config) WithSchemaVersionTable(table string) *Config { + a.RootConfig.SchemaVersionTable = table + return a +} + +// WithDSN sets the DSN to use +// To use SQLite, use "sqlite:path/to/file.db" +func (a *Config) WithDSN(dsn string) *Config { + a.RootConfig.DSN = dsn + return a +} + +// WithSchemaOutPath sets if the schema should be dumped before migration +// DefaultYamlConfig is "db/schema.sql" +func (a *Config) WithSchemaOutPath(path string) *Config { + a.RootConfig.SchemaOutPath = path + return a +} + +// WithMigrationDumpSchemaAfterMigrating sets if the schema should be dumped after migration +func (a *Config) WithMigrationDumpSchemaAfterMigrating(dumpSchema bool) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.DumpSchemaAfter = dumpSchema + + return a +} + +// WithShowSQL sets if the SQL should be shown in the output +func (a *Config) WithShowSQL(showSQL bool) *Config { + a.RootConfig.ShowSQL = showSQL + return a +} + +// WithJSON sets if the output should be in JSON +func (a *Config) WithJSON(json bool) *Config { + a.RootConfig.JSON = json + return a +} + +// WithShowSQLSyntaxHighlighting sets if the SQL should be highlighted +func (a *Config) WithShowSQLSyntaxHighlighting(highlight bool) *Config { + a.RootConfig.ShowSQLSyntaxHighlighting = highlight + return a +} + +// WithShellPath sets the path to the shell (used to execute commands like pg_dump) +func (a *Config) WithShellPath(path string) *Config { + a.RootConfig.ShellPath = path + return a +} + +// WithPGDumpPath sets the path to the pg_dump executable +// default is "pg_dump" +func (a *Config) WithPGDumpPath(path string) *Config { + a.RootConfig.PGDumpPath = path + return a +} + +// WithSchemaToDump sets the schema to dump +// default is "public" +func (a *Config) WithSchemaToDump(schema string) *Config { + a.RootConfig.SchemaToDump = schema + return a +} + +// WithDebug sets if the debug mode should be enabled +func (a *Config) WithDebug(debug bool) *Config { + a.RootConfig.Debug = debug + return a +} + +// WithMigrationDryRun sets if the migration should be dry run +// when true, the migration will not be executed, it's wrapped in a transaction +func (a *Config) WithMigrationDryRun(dryRun bool) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.DryRun = dryRun + return a +} + +// WithMigrationContinueOnError sets if the migration should continue on error +func (a *Config) WithMigrationContinueOnError(continueOnError bool) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.ContinueOnError = continueOnError + return a +} + +// WithMigrationTimeout sets the timeout for the migration +func (a *Config) WithMigrationTimeout(timeout time.Duration) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.Timeout = timeout + return a +} + +// WithMigrationVersion sets the version of the migration +// format be like: 20240502083700 or 20240502083700_name.{go, sql} +func (a *Config) WithMigrationVersion(version string) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.Version = version + return a +} + +// WithMigrationSteps sets the number of steps to migrate +// useful for rolling back +func (a *Config) WithMigrationSteps(steps int) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.Steps = steps + return a +} + +// WithMigrationUseSchemaDump sets if the schema should be dumped before migration +// when true, if there is no migrations and a schema dump exists, it will be used instead of applying all migrations +func (a *Config) WithMigrationUseSchemaDump(useSchemaDump bool) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.UseSchemaDump = useSchemaDump + return a +} + +// WithMigrationDumpSchemaAfter sets if the schema should be dumped after migration +func (a *Config) WithMigrationDumpSchemaAfter(dumpSchemaAfter bool) *Config { + if a.Migration == nil { + a.Migration = &MigrationConfig{} + } + a.Migration.DumpSchemaAfter = dumpSchemaAfter + return a +} + +// WithCreateType sets the type of the migration file +func (a *Config) WithCreateType(createType types.MigrationFileType) *Config { + if a.Create == nil { + a.Create = &CreateConfig{} + } + a.Create.Type = string(createType) + return a +} + +// WithCreateDump sets if the created file should contains the dump of the database +func (a *Config) WithCreateDump(dump bool) *Config { + if a.Create == nil { + a.Create = &CreateConfig{} + } + a.Create.Dump = dump + return a +} + +// WithCreateSQLSeparator sets the separator to split the down part of the migration in type sql +// DefaultYamlConfig value is "-- migrate:down" +func (a *Config) WithCreateSQLSeparator(separator string) *Config { + if a.Create == nil { + a.Create = &CreateConfig{} + } + a.Create.SQLSeparator = separator + return a +} + +func (a *Config) WithCreateSkip(skip bool) *Config { + if a.Create == nil { + a.Create = &CreateConfig{} + } + a.Create.Skip = skip + return a +} + +func (a *Config) WithCreateVersion(version string) *Config { + if a.Create == nil { + a.Create = &CreateConfig{} + } + a.Create.Version = version + return a +} diff --git a/pkg/amigoctx/ctx.go b/pkg/amigoctx/ctx.go deleted file mode 100644 index bf8086e..0000000 --- a/pkg/amigoctx/ctx.go +++ /dev/null @@ -1,282 +0,0 @@ -package amigoctx - -import ( - "errors" - "fmt" - "regexp" - "strings" - "time" - - "github.com/alexisvisco/amigo/pkg/types" -) - -var ( - ErrDSNEmpty = errors.New("dsn is empty") -) - -var ( - DefaultSchemaVersionTable = "public.mig_schema_versions" - DefaultAmigoFolder = "db" - DefaultMigrationFolder = "db/migrations" - DefaultPackagePath = "migrations" - DefaultShellPath = "/bin/bash" - DefaultPGDumpPath = "pg_dump" - DefaultSchemaOutPath = "db/schema.sql" - DefaultTimeout = 2 * time.Minute - DefaultDBDumpSchema = "public" -) - -type Context struct { - *Root - - Migration *Migration - Create *Create -} - -func NewContext() *Context { - return &Context{ - Root: &Root{ - SchemaVersionTable: DefaultSchemaVersionTable, - AmigoFolderPath: DefaultAmigoFolder, - MigrationFolder: DefaultMigrationFolder, - PackagePath: DefaultPackagePath, - ShellPath: DefaultShellPath, - PGDumpPath: DefaultPGDumpPath, - SchemaOutPath: DefaultSchemaOutPath, - SchemaDBDumpSchema: DefaultDBDumpSchema, - }, - Migration: &Migration{ - Timeout: DefaultTimeout, - Steps: 1, - }, - Create: &Create{}, - } -} - -type Root struct { - AmigoFolderPath string - DSN string - JSON bool - ShowSQL bool - ShowSQLSyntaxHighlighting bool - MigrationFolder string - PackagePath string - SchemaVersionTable string - ShellPath string - PGDumpPath string - SchemaOutPath string - SchemaDBDumpSchema string - Debug bool -} - -func (r *Context) GetRealDSN() string { - switch types.GetDriver(r.Root.DSN) { - case types.DriverSQLite: - return strings.TrimPrefix(r.Root.DSN, "sqlite:") - } - - return r.Root.DSN -} - -func (a *Context) WithAmigoFolder(folder string) *Context { - a.Root.AmigoFolderPath = folder - return a -} - -func (a *Context) WithMigrationFolder(folder string) *Context { - a.Root.MigrationFolder = folder - return a -} - -func (a *Context) WithPackagePath(packagePath string) *Context { - a.Root.PackagePath = packagePath - return a -} - -func (a *Context) WithSchemaVersionTable(table string) *Context { - a.Root.SchemaVersionTable = table - return a -} - -func (a *Context) WithDSN(dsn string) *Context { - a.Root.DSN = dsn - return a -} - -func (a *Context) WithVersion(version string) *Context { - a.Migration.Version = version - return a -} - -func (a *Context) WithDumpSchemaAfterMigrating(dumpSchema bool) *Context { - if a.Migration == nil { - a.Migration = &Migration{} - } - a.Migration.DumpSchemaAfter = dumpSchema - - return a -} - -func (a *Context) WithSteps(steps int) *Context { - a.Migration.Steps = steps - return a -} - -func (a *Context) WithShowSQL(showSQL bool) *Context { - a.Root.ShowSQL = showSQL - return a -} - -func (r *Root) ValidateDSN() error { - if r.DSN == "" { - return ErrDSNEmpty - } - - return nil -} - -type Migration struct { - Version string - Steps int - DryRun bool - ContinueOnError bool - Timeout time.Duration - UseSchemaDump bool - DumpSchemaAfter bool -} - -func (m *Migration) ValidateVersion() error { - if m.Version == "" { - return nil - } - - re := regexp.MustCompile(`\d{14}(_\w+)?\.go`) - if !re.MatchString(m.Version) { - return fmt.Errorf("version must be in the format: 20240502083700 or 20240502083700_name.go") - } - - return nil -} - -type Create struct { - Type string - Dump bool - - SQLSeparator string - - Skip bool - // Version is post setted after the name have been generated from the arg and time - Version string -} - -func (c *Create) ValidateType() error { - allowedTypes := []string{string(types.MigrationFileTypeClassic), string(types.MigrationFileTypeChange), string(types.MigrationFileTypeSQL)} - - for _, t := range allowedTypes { - if c.Type == t { - return nil - } - } - - return fmt.Errorf("unsupported type, allowed types are: %s", strings.Join(allowedTypes, ", ")) -} - -func MergeContext(toMerge Context) *Context { - defaultCtx := NewContext() - - if toMerge.Root != nil { - if toMerge.Root.AmigoFolderPath != "" { - defaultCtx.Root.AmigoFolderPath = toMerge.Root.AmigoFolderPath - } - - if toMerge.Root.DSN != "" { - defaultCtx.Root.DSN = toMerge.Root.DSN - } - - if toMerge.Root.JSON { - defaultCtx.Root.JSON = toMerge.Root.JSON - } - - if toMerge.Root.ShowSQL { - defaultCtx.Root.ShowSQL = toMerge.Root.ShowSQL - } - - if toMerge.Root.MigrationFolder != "" { - defaultCtx.Root.MigrationFolder = toMerge.Root.MigrationFolder - } - - if toMerge.Root.PackagePath != "" { - defaultCtx.Root.PackagePath = toMerge.Root.PackagePath - } - - if toMerge.Root.SchemaVersionTable != "" { - defaultCtx.Root.SchemaVersionTable = toMerge.Root.SchemaVersionTable - } - - if toMerge.Root.ShellPath != "" { - defaultCtx.Root.ShellPath = toMerge.Root.ShellPath - } - - if toMerge.Root.PGDumpPath != "" { - defaultCtx.Root.PGDumpPath = toMerge.Root.PGDumpPath - } - - if toMerge.Root.Debug { - defaultCtx.Root.Debug = toMerge.Root.Debug - } - - if toMerge.Root.SchemaDBDumpSchema != "" { - defaultCtx.Root.SchemaDBDumpSchema = toMerge.Root.SchemaDBDumpSchema - } - - if toMerge.Root.SchemaOutPath != "" { - defaultCtx.Root.SchemaOutPath = toMerge.Root.SchemaOutPath - } - } - - if toMerge.Migration != nil { - if toMerge.Migration.Version != "" { - defaultCtx.Migration.Version = toMerge.Migration.Version - } - - if toMerge.Migration.Steps != 0 { - defaultCtx.Migration.Steps = toMerge.Migration.Steps - } - - if toMerge.Migration.DryRun { - defaultCtx.Migration.DryRun = toMerge.Migration.DryRun - } - - if toMerge.Migration.ContinueOnError { - defaultCtx.Migration.ContinueOnError = toMerge.Migration.ContinueOnError - } - - if toMerge.Migration.Timeout != 0 { - defaultCtx.Migration.Timeout = toMerge.Migration.Timeout - } - - if toMerge.Migration.UseSchemaDump { - defaultCtx.Migration.UseSchemaDump = toMerge.Migration.UseSchemaDump - } - } - - if toMerge.Create != nil { - if toMerge.Create.Type != "" { - defaultCtx.Create.Type = toMerge.Create.Type - } - - if toMerge.Create.Dump { - defaultCtx.Create.Dump = toMerge.Create.Dump - } - - if toMerge.Create.Skip { - defaultCtx.Create.Skip = toMerge.Create.Skip - } - - if toMerge.Create.Version != "" { - defaultCtx.Create.Version = toMerge.Create.Version - } - } - - return defaultCtx -} diff --git a/pkg/entrypoint/context.go b/pkg/entrypoint/context.go new file mode 100644 index 0000000..5cf1b18 --- /dev/null +++ b/pkg/entrypoint/context.go @@ -0,0 +1,97 @@ +package entrypoint + +import ( + "fmt" + "path/filepath" + + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/amigoconfig" + "github.com/alexisvisco/amigo/pkg/utils" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" + "github.com/spf13/cobra" + "gopkg.in/yaml.v3" +) + +const contextsFileName = "contexts.yml" + +// contextCmd represents the context command +var contextCmd = &cobra.Command{ + Use: "context", + Short: "show the current context yaml file", + Long: `A context is a file inside the amigo folder that contains the flags that you use in the command line. + +Example: + amigo context --dsn "postgres://user:password@host:port/dbname?sslmode=disable" + +This command will create a file $amigo_folder/context.yaml with the content: + dsn: "postgres://user:password@host:port/dbname?sslmode=disable" +`, + Run: wrapCobraFunc(func(cmd *cobra.Command, a amigo.Amigo, args []string) error { + content, err := utils.GetFileContent(filepath.Join(a.Config.AmigoFolderPath, contextsFileName)) + if err != nil { + return fmt.Errorf("unable to read contexts file: %w", err) + } + + fmt.Println(string(content)) + + return nil + }), +} + +var ContextSetCmd = &cobra.Command{ + Use: "set", + Short: "Set the current context", + Run: wrapCobraFunc(func(cmd *cobra.Command, a amigo.Amigo, args []string) error { + yamlConfig, err := amigoconfig.LoadYamlConfig(filepath.Join(a.Config.AmigoFolderPath, contextsFileName)) + if err != nil { + return fmt.Errorf("unable to read contexts file: %w", err) + } + + if len(args) == 0 { + return fmt.Errorf("missing context name") + } + + if _, ok := yamlConfig.Contexts[args[0]]; !ok { + return fmt.Errorf("context %s not found", args[0]) + } + + yamlConfig.CurrentContext = args[0] + + file, err := utils.CreateOrOpenFile(filepath.Join(a.Config.AmigoFolderPath, contextsFileName)) + if err != nil { + return fmt.Errorf("unable to open contexts file: %w", err) + } + defer file.Close() + + err = file.Truncate(0) + if err != nil { + return fmt.Errorf("unable to truncate contexts file: %w", err) + } + + _, err = file.Seek(0, 0) + if err != nil { + return fmt.Errorf("unable to seek file: %w", err) + } + + yamlOut, err := yaml.Marshal(yamlConfig) + if err != nil { + return fmt.Errorf("unable to marshal yaml: %w", err) + } + + _, err = file.WriteString(string(yamlOut)) + if err != nil { + return fmt.Errorf("unable to write contexts file: %w", err) + } + + logger.Info(events.FileModifiedEvent{FileName: filepath.Join(a.Config.AmigoFolderPath, contextsFileName)}) + logger.Info(events.MessageEvent{Message: "context set to " + args[0]}) + + return nil + }), +} + +func init() { + rootCmd.AddCommand(contextCmd) + contextCmd.AddCommand(ContextSetCmd) +} diff --git a/cmd/create.go b/pkg/entrypoint/create.go similarity index 63% rename from cmd/create.go rename to pkg/entrypoint/create.go index 6cf280a..3f468a2 100644 --- a/cmd/create.go +++ b/pkg/entrypoint/create.go @@ -1,7 +1,8 @@ -package cmd +package entrypoint import ( "bytes" + "context" "fmt" "path" "path/filepath" @@ -25,17 +26,17 @@ var createCmd = &cobra.Command{ return fmt.Errorf("name is required: amigo create ") } - if err := cmdCtx.ValidateDSN(); err != nil { + if err := config.ValidateDSN(); err != nil { return err } - if err := cmdCtx.Create.ValidateType(); err != nil { + if err := config.Create.ValidateType(); err != nil { return err } inUp := "" - if cmdCtx.Create.Dump { + if config.Create.Dump { buffer := &bytes.Buffer{} err := am.DumpSchema(buffer, true) if err != nil { @@ -44,19 +45,19 @@ var createCmd = &cobra.Command{ inUp += fmt.Sprintf("s.Exec(`%s`)\n", buffer.String()) - cmdCtx.Create.Type = "classic" + config.Create.Type = "classic" } now := time.Now() version := now.UTC().Format(utils.FormatTime) - cmdCtx.Create.Version = version + config.Create.Version = version ext := "go" - if cmdCtx.Create.Type == "sql" { + if config.Create.Type == "sql" { ext = "sql" } migrationFileName := fmt.Sprintf("%s_%s.%s", version, flect.Underscore(args[0]), ext) - file, err := utils.CreateOrOpenFile(filepath.Join(cmdCtx.MigrationFolder, migrationFileName)) + file, err := utils.CreateOrOpenFile(filepath.Join(config.MigrationFolder, migrationFileName)) if err != nil { return fmt.Errorf("unable to open/create file: %w", err) } @@ -65,7 +66,7 @@ var createCmd = &cobra.Command{ Name: args[0], Up: inUp, Down: "", - Type: types.MigrationFileType(cmdCtx.Create.Type), + Type: types.MigrationFileType(config.Create.Type), Now: now, Writer: file, }) @@ -73,10 +74,10 @@ var createCmd = &cobra.Command{ return err } - logger.Info(events.FileAddedEvent{FileName: filepath.Join(cmdCtx.MigrationFolder, migrationFileName)}) + logger.Info(events.FileAddedEvent{FileName: filepath.Join(config.MigrationFolder, migrationFileName)}) // create the migrations file where all the migrations will be stored - file, err = utils.CreateOrOpenFile(path.Join(cmdCtx.MigrationFolder, migrationsFile)) + file, err = utils.CreateOrOpenFile(path.Join(config.MigrationFolder, migrationsFile)) if err != nil { return err } @@ -86,15 +87,20 @@ var createCmd = &cobra.Command{ return err } - logger.Info(events.FileModifiedEvent{FileName: path.Join(cmdCtx.MigrationFolder, migrationsFile)}) + logger.Info(events.FileModifiedEvent{FileName: path.Join(config.MigrationFolder, migrationsFile)}) - if cmdCtx.Create.Skip { - err = am.ExecuteMain(amigo.MainArgSkipMigration) + if config.Create.Skip { + db, err := database(*am.Config) if err != nil { - return err + return fmt.Errorf("unable to get database: %w", err) } - logger.Info(events.SkipMigrationEvent{MigrationVersion: version}) + ctx, cancelFunc := context.WithTimeout(context.Background(), am.Config.Migration.Timeout) + defer cancelFunc() + err = am.SkipMigrationFile(ctx, db) + if err != nil { + return err + } } return nil @@ -103,15 +109,15 @@ var createCmd = &cobra.Command{ func init() { rootCmd.AddCommand(createCmd) - createCmd.Flags().StringVar(&cmdCtx.Create.Type, "type", "change", + createCmd.Flags().StringVar(&config.Create.Type, "type", "change", "The type of migration to create, possible values are [classic, change, sql]") - createCmd.Flags().BoolVarP(&cmdCtx.Create.Dump, "dump", "d", false, + createCmd.Flags().BoolVarP(&config.Create.Dump, "dump", "d", false, "dump with pg_dump the current schema and add it to the current migration") - createCmd.Flags().StringVar(&cmdCtx.Create.SQLSeparator, "sql-separator", "-- migrate:down", + createCmd.Flags().StringVar(&config.Create.SQLSeparator, "sql-separator", "-- migrate:down", "the separator to split the up and down part of the migration") - createCmd.Flags().BoolVar(&cmdCtx.Create.Skip, "skip", false, + createCmd.Flags().BoolVar(&config.Create.Skip, "skip", false, "skip will set the migration as applied without executing it") } diff --git a/pkg/entrypoint/main.go b/pkg/entrypoint/main.go index 88911f6..f1183b5 100644 --- a/pkg/entrypoint/main.go +++ b/pkg/entrypoint/main.go @@ -2,131 +2,34 @@ package entrypoint import ( "database/sql" - "encoding/base64" - "encoding/json" - "flag" - "fmt" - "os" - "strings" - "text/tabwriter" "github.com/alexisvisco/amigo/pkg/amigo" - "github.com/alexisvisco/amigo/pkg/amigoctx" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/schema" - "github.com/alexisvisco/amigo/pkg/types" - "github.com/alexisvisco/amigo/pkg/utils" - "github.com/alexisvisco/amigo/pkg/utils/colors" - "github.com/alexisvisco/amigo/pkg/utils/events" - "github.com/alexisvisco/amigo/pkg/utils/logger" ) -func Main(db *sql.DB, arg amigo.MainArg, migrations []schema.Migration, ctx *amigoctx.Context) { - am := amigo.NewAmigo(ctx) - am.SetupSlog(os.Stdout, nil) +type DatabaseProvider func(cfg amigoconfig.Config) (*sql.DB, error) - switch arg { - case amigo.MainArgMigrate, amigo.MainArgRollback: - dir := types.MigrationDirectionUp - if arg == amigo.MainArgRollback { - dir = types.MigrationDirectionDown - } - err := am.RunMigrations(amigo.RunMigrationParams{ - DB: db, - Direction: dir, - Migrations: migrations, - LogOutput: os.Stdout, - }) - - if err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) - } - case amigo.MainArgSkipMigration: - err := am.SkipMigrationFile(db) - if err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) - } - case amigo.MainArgStatus: - versions, err := am.GetStatus(db) - if err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) - } - - hasVersion := func(version string) bool { - for _, v := range versions { - if v == version { - return true - } - } - return false - } - - // show status of 10 last migrations - b := &strings.Builder{} - tw := tabwriter.NewWriter(b, 2, 0, 1, ' ', 0) - - defaultMigrations := sliceArrayOrDefault(migrations, 10) - - for i, m := range defaultMigrations { - - key := fmt.Sprintf("(%s) %s", m.Date().UTC().Format(utils.FormatTime), m.Name()) - value := colors.Red("not applied") - if hasVersion(m.Date().UTC().Format(utils.FormatTime)) { - value = colors.Green("applied") - } - - fmt.Fprintf(tw, "%s\t\t%s", key, value) - if i != len(defaultMigrations)-1 { - fmt.Fprintln(tw) - } - } - - tw.Flush() - logger.Info(events.MessageEvent{Message: b.String()}) - } -} - -func sliceArrayOrDefault[T any](array []T, x int) []T { - defaultMigrations := array - if len(array) >= x { - defaultMigrations = array[len(array)-x:] - } - return defaultMigrations +type MainOptions struct { + CustomAmigo func(a *amigo.Amigo) amigo.Amigo } -func AmigoContextFromFlags() (*amigoctx.Context, amigo.MainArg) { - jsonFlag := flag.String("json", "", "all amigo context in json | bas64") +type MainOptFn func(options *MainOptions) - flag.Parse() +func Main(db DatabaseProvider, migrationsList []schema.Migration, opts ...MainOptFn) { + database = db + migrations = migrationsList - if flag.NArg() == 0 { - logger.Error(events.MessageEvent{Message: "missing argument"}) - os.Exit(1) + options := &MainOptions{} + for _, opt := range opts { + opt(options) } - arg := amigo.MainArg(flag.Arg(0)) - if err := arg.Validate(); err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) - } + _ = rootCmd.Execute() +} - a := amigoctx.NewContext() - if *jsonFlag != "" { - b64decoded, err := base64.StdEncoding.DecodeString(*jsonFlag) - if err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("unable to unmarshal amigo context b64: %s", - err.Error())}) - os.Exit(1) - } - err = json.Unmarshal(b64decoded, a) - if err != nil { - logger.Error(events.MessageEvent{Message: fmt.Sprintf("unable to unmarshal amigo context json: %s", - err.Error())}) - os.Exit(1) - } +func WithCustomAmigo(f func(a *amigo.Amigo) amigo.Amigo) MainOptFn { + return func(options *MainOptions) { + options.CustomAmigo = f } - - return a, arg } diff --git a/pkg/entrypoint/main_test.go b/pkg/entrypoint/main_test.go deleted file mode 100644 index a804525..0000000 --- a/pkg/entrypoint/main_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package entrypoint - -import ( - "github.com/stretchr/testify/require" - "testing" -) - -func Test_sliceArrayOrDefault(t *testing.T) { - type args struct { - iarray []int - x int - } - tests := []struct { - name string - args args - want []int - }{ - { - name: "should return the last 2 migrations", - args: args{ - iarray: []int{1, 2, 3}, - x: 2, - }, - - want: []int{2, 3}, - }, - - { - name: "should return the last 1 migration", - args: args{ - iarray: []int{1, 2, 3}, - x: 1, - }, - want: []int{3}, - }, - - { - name: "should return the last 3 migration if less than x", - args: args{ - iarray: []int{1, 2, 3}, - x: 5, - }, - want: []int{1, 2, 3}, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := sliceArrayOrDefault(tt.args.iarray, tt.args.x) - require.Equal(t, tt.want, got) - }) - } -} diff --git a/pkg/entrypoint/migration.go b/pkg/entrypoint/migration.go new file mode 100644 index 0000000..8203c74 --- /dev/null +++ b/pkg/entrypoint/migration.go @@ -0,0 +1,102 @@ +package entrypoint + +import ( + "context" + "fmt" + "os" + + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/amigoconfig" + "github.com/alexisvisco/amigo/pkg/types" + "github.com/spf13/cobra" +) + +// migrateCmd represents the up command +var migrateCmd = &cobra.Command{ + Use: "migrate", + Short: "Apply the database", + Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { + if err := config.ValidateDSN(); err != nil { + return err + } + + db, err := database(*am.Config) + if err != nil { + return fmt.Errorf("unable to get database: %w", err) + } + + ctx, cancelFunc := context.WithTimeout(context.Background(), am.Config.Migration.Timeout) + defer cancelFunc() + + err = am.RunMigrations(amigo.RunMigrationParams{ + DB: db, + Direction: types.MigrationDirectionUp, + Migrations: migrations, + LogOutput: os.Stdout, + Context: ctx, + }) + + if err != nil { + return fmt.Errorf("failed to migrate database: %w", err) + } + + return nil + }), +} + +// rollbackCmd represents the down command +var rollbackCmd = &cobra.Command{ + Use: "rollback", + Short: "Rollback the database", + Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { + if err := config.ValidateDSN(); err != nil { + return err + } + + db, err := database(*am.Config) + if err != nil { + return fmt.Errorf("unable to get database: %w", err) + } + + ctx, cancelFunc := context.WithTimeout(context.Background(), am.Config.Migration.Timeout) + defer cancelFunc() + + err = am.RunMigrations(amigo.RunMigrationParams{ + DB: db, + Direction: types.MigrationDirectionDown, + Migrations: migrations, + LogOutput: os.Stdout, + Context: ctx, + }) + + if err != nil { + return fmt.Errorf("failed to migrate database: %w", err) + } + + return nil + }), +} + +func init() { + rootCmd.AddCommand(rollbackCmd) + rootCmd.AddCommand(migrateCmd) + + registerBase := func(cmd *cobra.Command, m *amigoconfig.MigrationConfig) { + cmd.Flags().StringVar(&m.Version, "version", "", + "Apply a specific version format: 20240502083700 or 20240502083700_name.go") + cmd.Flags().BoolVar(&m.DryRun, "dry-run", false, "Run the migrations without applying them") + cmd.Flags().BoolVar(&m.ContinueOnError, "continue-on-error", false, + "Will not rollback the migration if an error occurs") + cmd.Flags().DurationVar(&m.Timeout, "timeout", amigoconfig.DefaultTimeout, "The timeout for the migration") + cmd.Flags().BoolVarP(&m.DumpSchemaAfter, "dump-schema-after", "d", false, + "Dump schema after migrate/rollback (not compatible with --use-schema-dump)") + } + + registerBase(migrateCmd, config.Migration) + migrateCmd.Flags().BoolVar(&config.Migration.UseSchemaDump, "use-schema-dump", false, + "Use the schema file to apply the migration (for fresh install without any migration)") + + registerBase(rollbackCmd, config.Migration) + rollbackCmd.Flags().IntVar(&config.Migration.Steps, "steps", 1, "The number of steps to rollback") + +} diff --git a/pkg/entrypoint/root.go b/pkg/entrypoint/root.go new file mode 100644 index 0000000..22372b6 --- /dev/null +++ b/pkg/entrypoint/root.go @@ -0,0 +1,133 @@ +package entrypoint + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/amigoconfig" + "github.com/alexisvisco/amigo/pkg/schema" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" + "github.com/spf13/cobra" +) + +var ( + config = amigoconfig.NewConfig() + database func(cfg amigoconfig.Config) (*sql.DB, error) + migrations []schema.Migration + customAmigoFn func(a *amigo.Amigo) *amigo.Amigo + migrationsFile = "migrations.go" +) + +// rootCmd represents the base command when called without any subcommands +var rootCmd = &cobra.Command{ + Use: "amigo", + Short: "Tool to manage database migrations with go files", + Long: `Basic usage: +First you need to create a main folder with amigo init: + + will create a folder in db with a context file inside to not have to pass the dsn every time. + + Postgres: + $ amigo context --dsn "postgres://user:password@host:port/dbname?sslmode=disable" + + SQLite: + $ amigo context --dsn "sqlite:/path/to/db.sqlite" --schema-version-table mig_schema_versions + + Unknown Driver (Mysql in this case): + $ amigo context --dsn "user:password@tcp(host:port)/dbname" + + + $ amigo init + note: will create: + - folder named migrations with a file named migrations.go that contains the list of migrations + - a new migration to create the schema version table + - a main.go in the $amigo_folder + +Apply migrations: + $ amigo migrate + note: you can set --version to migrate a specific version + +Create a new migration: + $ amigo create "create_table_users" + note: you can set --dump if you already have a database and you want to create the first migration with what's + already in the database. --skip will add the version of the created migration inside the schema version table. + +Rollback a migration: + $ amigo rollback + note: you can set --step to rollback a specific number of migrations, and --version to rollback + to a specific version +`, + SilenceUsage: true, +} + +func init() { + rootCmd.PersistentFlags().StringVarP(&config.AmigoFolderPath, "amigo-folder", "m", + amigoconfig.DefaultAmigoFolder, + "Folder path to use for creating amigo related files related to this repository") + + rootCmd.PersistentFlags().StringVar(&config.DSN, "dsn", "", + "The database connection string example: postgres://user:password@host:port/dbname?sslmode=disable") + + rootCmd.PersistentFlags().BoolVarP(&config.JSON, "json", "j", false, "Output in json format") + + rootCmd.PersistentFlags().StringVar(&config.MigrationFolder, "folder", amigoconfig.DefaultMigrationFolder, + "Folder where the migrations are stored") + + rootCmd.PersistentFlags().StringVarP(&config.MigrationPackageName, "package", "p", + amigoconfig.DefaultMigrationPackageName, + "Package name of the migrations folder") + + rootCmd.PersistentFlags().StringVarP(&config.SchemaVersionTable, "schema-version-table", "t", + amigoconfig.DefaultSchemaVersionTable, "Table name to keep track of the migrations") + + rootCmd.PersistentFlags().StringVar(&config.ShellPath, "shell-path", amigoconfig.DefaultShellPath, + "Shell to use (for: amigo create --dump, it uses pg dump command)") + + rootCmd.PersistentFlags().BoolVar(&config.ShowSQL, "sql", false, "Print SQL queries") + + rootCmd.PersistentFlags().BoolVar(&config.ShowSQLSyntaxHighlighting, "sql-syntax-highlighting", true, + "Print SQL queries with syntax highlighting") + + rootCmd.PersistentFlags().StringVar(&config.SchemaOutPath, "schema-out-path", amigoconfig.DefaultSchemaOutPath, + "File path of the schema dump if any") + + rootCmd.PersistentFlags().StringVar(&config.PGDumpPath, "pg-dump-path", amigoconfig.DefaultPGDumpPath, + "Path to the pg_dump command if --dump is set") + + rootCmd.PersistentFlags().StringVar(&config.SchemaToDump, "schema-to-dump", + amigoconfig.DefaultSchemaToDump, "Schema to use when dumping schema") + + rootCmd.PersistentFlags().BoolVar(&config.Debug, "debug", false, "Print debug information") + + initConfig() +} + +func initConfig() { + yamlConfig, err := amigoconfig.LoadYamlConfig(filepath.Join(config.AmigoFolderPath, contextsFileName)) + if err != nil { + logger.Error(events.MessageEvent{Message: fmt.Sprintf("error: can't read config: %s", err)}) + return + } + + config.OverrideWithYamlConfig(yamlConfig) +} + +func wrapCobraFunc(f func(cmd *cobra.Command, am amigo.Amigo, args []string) error) func(cmd *cobra.Command, args []string) { + return func(cmd *cobra.Command, args []string) { + + am := amigo.NewAmigo(config) + if customAmigoFn != nil { + am = *customAmigoFn(&am) + } + am.SetupSlog(os.Stdout, nil) + + if err := f(cmd, am, args); err != nil { + logger.Error(events.MessageEvent{Message: err.Error()}) + os.Exit(1) + } + } +} diff --git a/cmd/schema.go b/pkg/entrypoint/schema.go similarity index 83% rename from cmd/schema.go rename to pkg/entrypoint/schema.go index 4befaf2..9cddfef 100644 --- a/cmd/schema.go +++ b/pkg/entrypoint/schema.go @@ -1,4 +1,4 @@ -package cmd +package entrypoint import ( "fmt" @@ -18,7 +18,7 @@ var schemaCmd = &cobra.Command{ Supported databases: - postgres with pg_dump`, Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { - if err := cmdCtx.ValidateDSN(); err != nil { + if err := config.ValidateDSN(); err != nil { return err } @@ -27,7 +27,7 @@ Supported databases: } func dumpSchema(am amigo.Amigo) error { - file, err := utils.CreateOrOpenFile(cmdCtx.SchemaOutPath) + file, err := utils.CreateOrOpenFile(config.SchemaOutPath) if err != nil { return fmt.Errorf("unable to open/create file: %w", err) } @@ -39,7 +39,7 @@ func dumpSchema(am amigo.Amigo) error { return fmt.Errorf("unable to dump schema: %w", err) } - logger.Info(events.FileModifiedEvent{FileName: path.Join(cmdCtx.SchemaOutPath)}) + logger.Info(events.FileModifiedEvent{FileName: path.Join(config.SchemaOutPath)}) return nil } diff --git a/pkg/entrypoint/status.go b/pkg/entrypoint/status.go new file mode 100644 index 0000000..50ae3ee --- /dev/null +++ b/pkg/entrypoint/status.go @@ -0,0 +1,79 @@ +package entrypoint + +import ( + "context" + "fmt" + "strings" + "text/tabwriter" + + "github.com/alexisvisco/amigo/pkg/amigo" + "github.com/alexisvisco/amigo/pkg/utils" + "github.com/alexisvisco/amigo/pkg/utils/colors" + "github.com/alexisvisco/amigo/pkg/utils/events" + "github.com/alexisvisco/amigo/pkg/utils/logger" + "github.com/spf13/cobra" +) + +var statusCmd = &cobra.Command{ + Use: "status", + Short: "Status explain the current state of the database.", + Run: wrapCobraFunc(func(cmd *cobra.Command, am amigo.Amigo, args []string) error { + if err := config.ValidateDSN(); err != nil { + return err + } + + db, err := database(*am.Config) + if err != nil { + return fmt.Errorf("unable to get database: %w", err) + } + + ctx, cancelFunc := context.WithTimeout(context.Background(), am.Config.Migration.Timeout) + defer cancelFunc() + + versions, err := am.GetStatus(ctx, db) + if err != nil { + return fmt.Errorf("unable to get status: %w", err) + } + + hasVersion := func(version string) bool { + for _, v := range versions { + if v == version { + return true + } + } + return false + } + + // show status of 10 last migrations + b := &strings.Builder{} + tw := tabwriter.NewWriter(b, 2, 0, 1, ' ', 0) + defaultMigrations := sliceArrayOrDefault(migrations, 10) + for i, m := range defaultMigrations { + key := fmt.Sprintf("(%s) %s", m.Date().UTC().Format(utils.FormatTime), m.Name()) + value := colors.Red("not applied") + if hasVersion(m.Date().UTC().Format(utils.FormatTime)) { + value = colors.Green("applied") + } + fmt.Fprintf(tw, "%s\t\t%s", key, value) + if i != len(defaultMigrations)-1 { + fmt.Fprintln(tw) + } + } + tw.Flush() + logger.Info(events.MessageEvent{Message: b.String()}) + + return nil + }), +} + +func sliceArrayOrDefault[T any](array []T, x int) []T { + defaultMigrations := array + if len(array) >= x { + defaultMigrations = array[len(array)-x:] + } + return defaultMigrations +} + +func init() { + rootCmd.AddCommand(statusCmd) +} diff --git a/pkg/schema/base/base.go b/pkg/schema/base/base.go index 7c31a0e..377eabd 100644 --- a/pkg/schema/base/base.go +++ b/pkg/schema/base/base.go @@ -52,7 +52,7 @@ func (p *Schema) AddVersion(version string) { sql := `INSERT INTO {version_table} (version) VALUES ({version})` replacer := utils.Replacer{ - "version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()), + "version_table": utils.StrFunc(p.Context.Config.SchemaVersionTable), "version": utils.StrFunc(fmt.Sprintf("'%s'", version)), } @@ -68,7 +68,7 @@ func (p *Schema) AddVersion(version string) { func (p Schema) AddVersions(versions []string) { sql := `INSERT INTO {version_table} (version) VALUES {versions}` replacer := utils.Replacer{ - "version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()), + "version_table": utils.StrFunc(p.Context.Config.SchemaVersionTable), "versions": utils.StrFunc(fmt.Sprintf("('%s')", strings.Join(versions, "'), ('"))), } @@ -87,7 +87,7 @@ func (p *Schema) RemoveVersion(version string) { sql := `DELETE FROM {version_table} WHERE version = {version}` replacer := utils.Replacer{ - "version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()), + "version_table": utils.StrFunc(p.Context.Config.SchemaVersionTable), "version": utils.StrFunc(fmt.Sprintf("'%s'", version)), } @@ -105,7 +105,7 @@ func (p *Schema) FindAppliedVersions() []string { sql := `SELECT version FROM {version_table} ORDER BY version ASC` replacer := utils.Replacer{ - "version_table": utils.StrFunc(p.Context.MigratorOptions.SchemaVersionTable.String()), + "version_table": utils.StrFunc(p.Context.Config.SchemaVersionTable), } rows, err := p.TX.QueryContext(p.Context.Context, replacer.Replace(sql)) diff --git a/pkg/schema/detect_migrations.go b/pkg/schema/detect_migrations.go index 8f87725..f70656a 100644 --- a/pkg/schema/detect_migrations.go +++ b/pkg/schema/detect_migrations.go @@ -20,7 +20,7 @@ func (m *Migrator[T]) detectMigrationsToExec( firstRun = true appliedVersions = []string{} } else if err != nil { - m.ctx.RaiseError(err) + m.migratorContext.RaiseError(err) } var versionsToApply []Migration @@ -36,11 +36,11 @@ func (m *Migrator[T]) detectMigrationsToExec( case types.MigrationDirectionUp: if version != nil && *version != "" { if _, ok := versionToMigration[*version]; !ok { - m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) + m.migratorContext.RaiseError(fmt.Errorf("version %s not found", *version)) } if slices.Contains(appliedVersions, *version) { - m.ctx.RaiseError(fmt.Errorf("version %s already applied", *version)) + m.migratorContext.RaiseError(fmt.Errorf("version %s already applied", *version)) } versionsToApply = append(versionsToApply, versionToMigration[*version]) @@ -55,11 +55,11 @@ func (m *Migrator[T]) detectMigrationsToExec( case types.MigrationDirectionDown: if version != nil && *version != "" { if _, ok := versionToMigration[*version]; !ok { - m.ctx.RaiseError(fmt.Errorf("version %s not found", *version)) + m.migratorContext.RaiseError(fmt.Errorf("version %s not found", *version)) } if !slices.Contains(appliedVersions, *version) { - m.ctx.RaiseError(fmt.Errorf("version %s not applied", *version)) + m.migratorContext.RaiseError(fmt.Errorf("version %s not applied", *version)) } versionsToApply = append(versionsToApply, versionToMigration[*version]) diff --git a/pkg/schema/migrator.go b/pkg/schema/migrator.go index 6ffff2a..7fbeb8a 100644 --- a/pkg/schema/migrator.go +++ b/pkg/schema/migrator.go @@ -6,6 +6,7 @@ import ( "reflect" "time" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils" "github.com/alexisvisco/amigo/pkg/utils/dblog" @@ -59,33 +60,37 @@ type Factory[T Schema] func(ctx *MigratorContext, tx DB, db DB) T // Migrator applies the migrations. type Migrator[T Schema] struct { - db DBTX - ctx *MigratorContext + db DBTX + migratorContext *MigratorContext schemaFactory Factory[T] migrations []func(T) } +func (m *Migrator[T]) GetSchema() Schema { + return m.schemaFactory(m.migratorContext, m.db, m.db) +} + // NewMigrator creates a new migrator. func NewMigrator[T Schema]( ctx context.Context, db DBTX, schemaFactory Factory[T], - opts *MigratorOption, + config *amigoconfig.Config, ) *Migrator[T] { return &Migrator[T]{ db: db, schemaFactory: schemaFactory, - ctx: &MigratorContext{ + migratorContext: &MigratorContext{ Context: ctx, - MigratorOptions: opts, + Config: config, MigrationEvents: &MigrationEvents{}, }, } } func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, steps *int, migrations []Migration) bool { - db := m.schemaFactory(m.ctx, m.db, m.db) + db := m.schemaFactory(m.migratorContext, m.db, m.db) migrationsToExecute, firstRun := m.detectMigrationsToExec( db, @@ -100,7 +105,7 @@ func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, return true } - if firstRun && m.ctx.MigratorOptions.UseSchemaDump { + if firstRun && m.migratorContext.Config.Migration.UseSchemaDump { logger.Info(events.MessageEvent{Message: "We detect a fresh installation and applied the schema dump"}) err := m.tryMigrateWithSchemaDump(migrationsToExecute) if err != nil { @@ -160,16 +165,9 @@ func (m *Migrator[T]) Apply(direction types.MigrationDirection, version *string, } func (m *Migrator[T]) NewSchema() T { - return m.schemaFactory(m.ctx, m.db, m.db) -} - -// Options returns a copy of the options. -func (m *Migrator[T]) Options() MigratorOption { - return *m.ctx.MigratorOptions + return m.schemaFactory(m.migratorContext, m.db, m.db) } func (m *Migrator[T]) ToggleDBLog(b bool) { - if m.Options().DBLogger != nil { - m.Options().DBLogger.ToggleLogger(b) - } + // todo: adjust logger } diff --git a/pkg/schema/migrator_context.go b/pkg/schema/migrator_context.go index dbcbbd4..5f54fce 100644 --- a/pkg/schema/migrator_context.go +++ b/pkg/schema/migrator_context.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/types" "github.com/alexisvisco/amigo/pkg/utils/events" "github.com/alexisvisco/amigo/pkg/utils/logger" @@ -15,7 +16,7 @@ import ( type MigratorContext struct { Context context.Context errors error - MigratorOptions *MigratorOption + Config *amigoconfig.Config MigrationEvents *MigrationEvents MigrationDirection types.MigrationDirection @@ -69,11 +70,11 @@ func NewForceStopError(err error) *ForceStopError { func (m *MigratorContext) RaiseError(err error) { m.addError(err) isForceStopError := errors.Is(err, &ForceStopError{}) - if !m.MigratorOptions.ContinueOnError && !isForceStopError { + if !m.Config.Migration.ContinueOnError && !isForceStopError { panic(err) } else { logger.Info(events.MessageEvent{ - Message: fmt.Sprintf("migration error found, continue due to `continue_on_error` option: %s", err.Error()), + Message: fmt.Sprintf("continue due to `continue_on_error` option: %s", err.Error()), }) } } diff --git a/pkg/schema/pg/postgres_test.go b/pkg/schema/pg/postgres_test.go index f9ff18e..7856be1 100644 --- a/pkg/schema/pg/postgres_test.go +++ b/pkg/schema/pg/postgres_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/schema" "github.com/alexisvisco/amigo/pkg/utils" "github.com/alexisvisco/amigo/pkg/utils/dblog" @@ -94,7 +95,7 @@ func initSchema(t *testing.T, name string, number ...int32) (*sql.DB, dblog.Data _, err = conn.ExecContext(context.Background(), fmt.Sprintf("CREATE SCHEMA %s", schemaName)) require.NoError(t, err) - mig := schema.NewMigrator(context.Background(), conn, NewPostgres, &schema.MigratorOption{}) + mig := schema.NewMigrator(context.Background(), conn, NewPostgres, &amigoconfig.Config{}) return conn, recorder, mig, schemaName } @@ -148,7 +149,7 @@ func TestPostgres_Versions(t *testing.T) { versionTable("tst_pg_add_version", p) - p.Context.MigratorOptions.SchemaVersionTable = schema.Table("mig_schema_version", "tst_pg_add_version") + p.Context.Config.SchemaVersionTable = schema.Table("mig_schema_version", "tst_pg_add_version").String() p.AddVersion("v1") versions := p.FindAppliedVersions() diff --git a/pkg/schema/run_migration.go b/pkg/schema/run_migration.go index b5c2db1..6f2f9e9 100644 --- a/pkg/schema/run_migration.go +++ b/pkg/schema/run_migration.go @@ -10,7 +10,7 @@ import ( // run runs the migration. func (m *Migrator[T]) run(migrationType types.MigrationDirection, version string, f func(T)) (ok bool) { - currentContext := m.ctx + currentContext := m.migratorContext currentContext.MigrationDirection = migrationType tx, err := m.db.BeginTx(currentContext.Context, nil) @@ -49,7 +49,7 @@ func (m *Migrator[T]) run(migrationType types.MigrationDirection, version string schema.RemoveVersion(version) } - if m.ctx.MigratorOptions.DryRun { + if m.migratorContext.Config.Migration.DryRun { logger.Info(events.MessageEvent{Message: "migration in dry run mode, rollback transaction..."}) err := tx.Rollback() if err != nil { diff --git a/pkg/schema/run_migration_schema_dump.go b/pkg/schema/run_migration_schema_dump.go index 3e17b62..58e10bf 100644 --- a/pkg/schema/run_migration_schema_dump.go +++ b/pkg/schema/run_migration_schema_dump.go @@ -13,33 +13,33 @@ import ( // this might be executed when the user arrives on a repo with a schema.sql, instead of running // all the migrations we will try to dump the schema and apply it. Then tell we applied all versions. func (m *Migrator[T]) tryMigrateWithSchemaDump(migrations []Migration) error { - if m.ctx.MigratorOptions.DumpSchemaFilePath == nil { + if m.migratorContext.Config.PGDumpPath == "" { return errors.New("no schema dump file path provided") } - file, err := os.ReadFile(*m.ctx.MigratorOptions.DumpSchemaFilePath) + file, err := os.ReadFile(m.migratorContext.Config.PGDumpPath) if err != nil { return fmt.Errorf("unable to read schema dump file: %w", err) } logger.ShowSQLEvents = false - tx, err := m.db.BeginTx(m.ctx.Context, nil) + tx, err := m.db.BeginTx(m.migratorContext.Context, nil) if err != nil { return fmt.Errorf("unable to start transaction: %w", err) } defer tx.Rollback() - tx.ExecContext(m.ctx.Context, "SET search_path TO public") - _, err = tx.ExecContext(m.ctx.Context, string(file)) + tx.ExecContext(m.migratorContext.Context, "SET search_path TO public") + _, err = tx.ExecContext(m.migratorContext.Context, string(file)) if err != nil { return fmt.Errorf("unable to apply schema dump: %w", err) } tx.Commit() - schema := m.NewSchema() + schema := m.GetSchema() versions := make([]string, 0, len(migrations)) for _, migration := range migrations { diff --git a/pkg/schema/sqlite/sqlite_test.go b/pkg/schema/sqlite/sqlite_test.go index f7b5f0f..11f63b6 100644 --- a/pkg/schema/sqlite/sqlite_test.go +++ b/pkg/schema/sqlite/sqlite_test.go @@ -3,16 +3,18 @@ package sqlite import ( "context" "database/sql" + "log/slog" + "os" + "path" + "testing" + + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/schema" "github.com/alexisvisco/amigo/pkg/utils/dblog" "github.com/alexisvisco/amigo/pkg/utils/logger" _ "github.com/mattn/go-sqlite3" sqldblogger "github.com/simukti/sqldb-logger" "github.com/stretchr/testify/require" - "log/slog" - "os" - "path" - "testing" ) func connect(t *testing.T) (*sql.DB, dblog.DatabaseLogger) { @@ -40,7 +42,7 @@ func connect(t *testing.T) (*sql.DB, dblog.DatabaseLogger) { func baseTest(t *testing.T, init string) (postgres *Schema, rec dblog.DatabaseLogger) { db, rec := connect(t) - m := schema.NewMigrator(context.Background(), db, NewSQLite, &schema.MigratorOption{}) + m := schema.NewMigrator(context.Background(), db, NewSQLite, &amigoconfig.Config{}) if init != "" { _, err := db.ExecContext(context.Background(), init) diff --git a/pkg/templates/init_create_table_base.go.tmpl b/pkg/templates/init_create_table_base.go.tmpl index c1f5df1..f7dc33e 100644 --- a/pkg/templates/init_create_table_base.go.tmpl +++ b/pkg/templates/init_create_table_base.go.tmpl @@ -1,2 +1 @@ -s.Exec(`CREATE TABLE IF NOT EXISTS {{ .Name }} ( "version" VARCHAR(255) NOT NULL PRIMARY KEY )`) -} \ No newline at end of file +s.Exec(`CREATE TABLE IF NOT EXISTS {{ .Name }} ( "version" VARCHAR(255) NOT NULL PRIMARY KEY )`) \ No newline at end of file diff --git a/pkg/templates/main.go.tmpl b/pkg/templates/main.go.tmpl index 27a38ec..86dbe5f 100644 --- a/pkg/templates/main.go.tmpl +++ b/pkg/templates/main.go.tmpl @@ -3,22 +3,17 @@ package main import ( "database/sql" - migrations "{{ .PackagePath }}" + + "github.com/alexisvisco/amigo/pkg/amigoconfig" "github.com/alexisvisco/amigo/pkg/entrypoint" - "github.com/alexisvisco/amigo/pkg/utils/events" - "github.com/alexisvisco/amigo/pkg/utils/logger" - _ "{{ .DriverPath }}" - "os" + migrations "{{ .PackagePath }}" + _ "github.com/jackc/pgx/v5/stdlib" ) func main() { - opts, arg := entrypoint.AmigoContextFromFlags() - - db, err := sql.Open("{{ .DriverName }}", opts.GetRealDSN()) - if err != nil { - logger.Error(events.MessageEvent{Message: err.Error()}) - os.Exit(1) + databaseProvider := func(cfg amigoconfig.Config) (*sql.DB, error) { + return sql.Open("pgx", cfg.GetRealDSN()) } - entrypoint.Main(db, arg, migrations.Migrations, opts) + entrypoint.Main(databaseProvider, migrations.Migrations) } diff --git a/pkg/templates/migrations.go.tmpl b/pkg/templates/migrations.go.tmpl index 26ecff1..5fccd86 100644 --- a/pkg/templates/migrations.go.tmpl +++ b/pkg/templates/migrations.go.tmpl @@ -4,8 +4,8 @@ package {{ .Package }} import ( "github.com/alexisvisco/amigo/pkg/schema" - "embed" - {{if .ImportSchemaPackage}}"{{ .ImportSchemaPackage }}"{{end}} +{{if .ImportSchemaPackage}} "embed" + "{{ .ImportSchemaPackage }}"{{end}} ) {{if .ImportSchemaPackage}} //go:embed *.sql diff --git a/pkg/utils/files.go b/pkg/utils/files.go index d9845ce..9f6cc1c 100644 --- a/pkg/utils/files.go +++ b/pkg/utils/files.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "io" "os" "path/filepath" ) @@ -17,6 +18,21 @@ func CreateOrOpenFile(path string) (*os.File, error) { return os.OpenFile(path, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0644) } +func GetFileContent(path string) ([]byte, error) { + file, err := os.Open(path) + if err != nil { + return nil, err + } + defer file.Close() + + content, err := io.ReadAll(file) + if err != nil { + return nil, err + } + + return content, nil +} + func EnsurePrentDirExists(path string) error { if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { if !os.IsExist(err) {