diff --git a/pkg/migrations/duplicate.go b/pkg/migrations/duplicate.go index d35b3f80..39c0af4f 100644 --- a/pkg/migrations/duplicate.go +++ b/pkg/migrations/duplicate.go @@ -42,14 +42,22 @@ func (d *Duplicator) WithType(t string) *Duplicator { func (d *Duplicator) Duplicate(ctx context.Context) error { const ( cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s` + cSetDefaultSQL = `ALTER COLUMN %s SET DEFAULT %s` cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)` ) + // Generate SQL to duplicate the column's name and type sql := fmt.Sprintf(cAlterTableSQL, pq.QuoteIdentifier(d.table.Name), pq.QuoteIdentifier(d.asName), d.withType) + // Generate SQL to duplicate the column's default value + if d.column.Default != nil { + sql += fmt.Sprintf(", "+cSetDefaultSQL, d.asName, *d.column.Default) + } + + // Generate SQL to duplicate any foreign key constraints on the column for _, fk := range d.table.ForeignKeys { if slices.Contains(fk.Columns, d.column.Name) { sql += fmt.Sprintf(", "+cAddForeignKeySQL, diff --git a/pkg/migrations/op_change_type_test.go b/pkg/migrations/op_change_type_test.go index b8208048..cea81bbc 100644 --- a/pkg/migrations/op_change_type_test.go +++ b/pkg/migrations/op_change_type_test.go @@ -222,6 +222,71 @@ func TestChangeColumnType(t *testing.T) { ConstraintMustExist(t, db, "public", "employees", "fk_employee_department") }, }, + { + name: "changing column type preserves any defaults on the column", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "users", + Columns: []migrations.Column{ + { + Name: "id", + Type: "integer", + Pk: true, + }, + { + Name: "username", + Type: "text", + Default: ptr("'alice'"), + Nullable: true, + }, + }, + }, + }, + }, + { + Name: "02_change_type", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "users", + Column: "username", + Type: "varchar(255)", + Up: "username", + Down: "username", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_change_type", "users", map[string]string{ + "id": "1", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_change_type", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "alice"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_change_type", "users", map[string]string{ + "id": "2", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_change_type", "users") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "alice"}, + {"id": 2, "username": "alice"}, + }, rows) + }, + }, }) }