diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index 164511b8..20025a74 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -496,6 +496,36 @@ func insert(t *testing.T, db *sql.DB, schema, version, table string, record map[ return err } +func MustDelete(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) { + t.Helper() + + if err := delete(t, db, schema, version, table, record); err != nil { + t.Fatal(err) + } +} + +func delete(t *testing.T, db *sql.DB, schema, version, table string, record map[string]string) error { + t.Helper() + versionSchema := roll.VersionedSchemaName(schema, version) + + cols := maps.Keys(record) + slices.Sort(cols) + + recordStr := "" + for i, c := range cols { + if i > 0 { + recordStr += " AND " + } + recordStr += fmt.Sprintf("%s = '%s'", c, record[c]) + } + + //nolint:gosec // this is a test so we don't care about SQL injection + stmt := fmt.Sprintf("DELETE FROM %s.%s WHERE %s", versionSchema, table, recordStr) + + _, err := db.Exec(stmt) + return err +} + func MustSelect(t *testing.T, db *sql.DB, schema, version, table string) []map[string]any { t.Helper() versionSchema := roll.VersionedSchemaName(schema, version) diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index 76a3f145..3b61494f 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -26,7 +26,8 @@ func (o *OpSetUnique) Start(ctx context.Context, conn *sql.DB, stateSchema strin column := table.GetColumn(o.Column) // create a copy of the column on the underlying table. - if err := duplicateColumn(ctx, conn, table, *column); err != nil { + d := NewColumnDuplicator(conn, table, column) + if err := d.Duplicate(ctx); err != nil { return fmt.Errorf("failed to duplicate column: %w", err) } @@ -115,10 +116,11 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn *sql.DB, s *schema.Sche } // Rename the new column to the old column name - _, err = conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", - pq.QuoteIdentifier(o.Table), - pq.QuoteIdentifier(TemporaryName(o.Column)), - pq.QuoteIdentifier(o.Column))) + table := s.GetTable(o.Table) + column := table.GetColumn(o.Column) + if err := RenameDuplicatedColumn(ctx, conn, table, column); err != nil { + return err + } return err } diff --git a/pkg/migrations/op_set_unique_test.go b/pkg/migrations/op_set_unique_test.go index 4d483fd3..d660c0f1 100644 --- a/pkg/migrations/op_set_unique_test.go +++ b/pkg/migrations/op_set_unique_test.go @@ -116,7 +116,7 @@ func TestSetColumnUnique(t *testing.T) { }, }, { - name: "set unique with default user supplied down sql", + name: "set unique with user supplied down sql", migrations: []migrations.Migration{ { Name: "01_add_table", @@ -179,5 +179,164 @@ func TestSetColumnUnique(t *testing.T) { afterComplete: func(t *testing.T, db *sql.DB) { }, }, + { + name: "column defaults are preserved when adding a unique constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "reviews", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "username", + Type: "text", + Default: ptr("'anonymous'"), + }, + { + Name: "product", + Type: "text", + }, + { + Name: "review", + Type: "text", + }, + }, + }, + }, + }, + { + Name: "02_set_unique", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "reviews", + Column: "username", + Unique: &migrations.UniqueConstraint{ + Name: "reviews_username_unique", + }, + Up: "username", + Down: "username", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "product": "apple", "review": "awesome", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_set_unique", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 1, "username": "anonymous", "product": "apple", "review": "awesome"}, + }, rows) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // Delete the row that was inserted in the `afterStart` hook to + // ensure that another row with a default 'username' can be inserted + // without violating the UNIQUE constraint on the column. + MustDelete(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "id": "1", + }) + + // A row can be inserted into the new version of the table. + MustInsert(t, db, "public", "02_set_unique", "reviews", map[string]string{ + "product": "banana", "review": "bent", + }) + + // The newly inserted row respects the default value of the column. + rows := MustSelect(t, db, "public", "02_set_unique", "reviews") + assert.Equal(t, []map[string]any{ + {"id": 2, "username": "anonymous", "product": "banana", "review": "bent"}, + }, rows) + }, + }, + { + name: "foreign keys defined on the column are preserved when adding a unique constraint", + migrations: []migrations.Migration{ + { + Name: "01_add_departments_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "departments", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, + }, + }, + }, + }, + { + Name: "02_add_employees_table", + Operations: migrations.Operations{ + &migrations.OpCreateTable{ + Name: "employees", + Columns: []migrations.Column{ + { + Name: "id", + Type: "serial", + Pk: true, + }, + { + Name: "name", + Type: "text", + Nullable: false, + }, + { + Name: "department_id", + Type: "integer", + Nullable: true, + References: &migrations.ForeignKeyReference{ + Name: "fk_employee_department", + Table: "departments", + Column: "id", + }, + }, + }, + }, + }, + }, + { + Name: "03_set_unique", + Operations: migrations.Operations{ + &migrations.OpAlterColumn{ + Table: "employees", + Column: "department_id", + Unique: &migrations.UniqueConstraint{ + Name: "employees_department_id_unique", + }, + Up: "department_id", + Down: "department_id", + }, + }, + }, + }, + afterStart: func(t *testing.T, db *sql.DB) { + // A temporary FK constraint has been created on the temporary column + ConstraintMustExist(t, db, "public", "employees", migrations.TemporaryName("fk_employee_department")) + }, + afterRollback: func(t *testing.T, db *sql.DB) { + }, + afterComplete: func(t *testing.T, db *sql.DB) { + // The foreign key constraint still exists on the column + ConstraintMustExist(t, db, "public", "employees", "fk_employee_department") + }, + }, }) }