diff --git a/internal/database/db.go b/internal/database/db.go index 9dde50d83..106cf977e 100644 --- a/internal/database/db.go +++ b/internal/database/db.go @@ -14,9 +14,9 @@ type Database interface { GetRoleByID(ctx context.Context, id gidx.PrefixedID) (*Role, error) GetResourceRoleByName(ctx context.Context, resourceID gidx.PrefixedID, name string) (*Role, error) ListResourceRoles(ctx context.Context, resourceID gidx.PrefixedID) ([]*Role, error) - CreateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Role, error) - UpdateRole(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Role, error) - DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (*Role, error) + CreateRoleTransaction(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Transaction[*Role], error) + UpdateRoleTransaction(ctx context.Context, actorID gidx.PrefixedID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Transaction[*Role], error) + DeleteRoleTransaction(ctx context.Context, roleID gidx.PrefixedID) (*Transaction[*Role], error) HealthCheck(ctx context.Context) error } diff --git a/internal/database/roles.go b/internal/database/roles.go index d9e837100..93e69e6f9 100644 --- a/internal/database/roles.go +++ b/internal/database/roles.go @@ -8,9 +8,11 @@ import ( "time" "go.infratographer.com/x/gidx" - "go.uber.org/zap" ) +// TxRole defines a Role Transaction. +type TxRole = *Transaction[*Role] + // Role represents a role in the database. type Role struct { ID gidx.PrefixedID @@ -19,37 +21,6 @@ type Role struct { CreatorID gidx.PrefixedID CreatedAt time.Time UpdatedAt time.Time - - logger *zap.SugaredLogger - commit func() error - rollback func() error -} - -// Commit calls commit on the transaction if the role has been created within a transaction. -// If not the method returns an ErrMethodUnavailable error. -func (r *Role) Commit() error { - if r.commit == nil { - return ErrMethodUnavailable - } - - return r.commit() -} - -// Rollback calls rollback on the transaction if the role has been created within a transaction. -// If not the method returns an ErrMethodUnavailable error. -// -// To simplify rollbacks, logging has automatically been setup to log any errors produced if a rollback fails. -func (r *Role) Rollback() error { - if r.rollback == nil { - return ErrMethodUnavailable - } - - err := r.rollback() - if err != nil && !errors.Is(err, sql.ErrTxDone) { - r.logger.Errorw("failed to rollback role", "role_id", r.ID, zap.Error(err)) - } - - return err } // GetRoleByID retrieves a role from the database by the provided prefixed ID. @@ -165,10 +136,13 @@ func (db *database) ListResourceRoles(ctx context.Context, resourceID gidx.Prefi return roles, nil } -// CreateRole creates a role with the provided details returning the new Role entry. +// CreateRoleTransaction creates a role with the provided details in a new transaction which must be committed. // If a role already exists with the given roleID an ErrRoleAlreadyExists error is returned. // If a role already exists with the same name under the given resource ID then an ErrRoleNameTaken error is returned. -func (db *database) CreateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Role, error) { +// +// Transaction.Commit() or Transaction.Rollback() should be called if error is nil otherwise the database will hold +// the indexes waiting for the transaction to complete. +func (db *database) CreateRoleTransaction(ctx context.Context, actorID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (TxRole, error) { var role Role tx, err := db.BeginTx(ctx, nil) @@ -203,17 +177,16 @@ func (db *database) CreateRole(ctx context.Context, actorID, roleID gidx.Prefixe return nil, err } - role.logger = db.logger.Named("role") - role.commit = tx.Commit - role.rollback = tx.Rollback - - return &role, nil + return newTransaction(db.logger.With("role_id", role.ID), tx, &role), nil } -// UpdateRole updates an existing role if one exists. -// If no role already exists, a new role is created in the same way as CreateRole. +// UpdateRoleTransaction starts a new transaction to update an existing role if one exists. +// If no role already exists, a new role is created in the same way as CreateRoleTransaction. // If changing the name and the new name results in a duplicate name error, an ErrRoleNameTaken error is returned. -func (db *database) UpdateRole(ctx context.Context, actorID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (*Role, error) { +// +// Transaction.Commit() or Transaction.Rollback() should be called if error is nil otherwise the database will hold +// the indexes waiting for the transaction to complete. +func (db *database) UpdateRoleTransaction(ctx context.Context, actorID, roleID gidx.PrefixedID, name string, resourceID gidx.PrefixedID) (TxRole, error) { var role Role tx, err := db.BeginTx(ctx, nil) @@ -245,15 +218,15 @@ func (db *database) UpdateRole(ctx context.Context, actorID, roleID gidx.Prefixe return nil, err } - role.logger = db.logger.Named("role") - role.commit = tx.Commit - role.rollback = tx.Rollback - - return &role, nil + return newTransaction(db.logger.With("role_id", role.ID), tx, &role), nil } -// DeleteRole deletes the role id provided, if no rows are affected an ErrNoRoleFound error is returned. -func (db *database) DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (*Role, error) { +// DeleteRoleTransaction starts a new transaction to delete the role for the id provided. +// If no rows are affected an ErrNoRoleFound error is returned. +// +// Transaction.Commit() or Transaction.Rollback() should be called if error is nil otherwise the database will hold +// the indexes waiting for the transaction to complete. +func (db *database) DeleteRoleTransaction(ctx context.Context, roleID gidx.PrefixedID) (TxRole, error) { tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err @@ -273,10 +246,9 @@ func (db *database) DeleteRole(ctx context.Context, roleID gidx.PrefixedID) (*Ro return nil, ErrNoRoleFound } - return &Role{ - ID: roleID, - logger: db.logger.Named("role"), - commit: tx.Commit, - rollback: tx.Rollback, - }, nil + role := Role{ + ID: roleID, + } + + return newTransaction(db.logger.With("role_id", role.ID), tx, &role), nil } diff --git a/internal/database/roles_test.go b/internal/database/roles_test.go index 68459b7fe..16b87cda3 100644 --- a/internal/database/roles_test.go +++ b/internal/database/roles_test.go @@ -29,11 +29,11 @@ func TestGetRoleByID(t *testing.T) { assert.ErrorIs(t, err, database.ErrNoRoleFound) require.Nil(t, role, "no role expected to be returned") - createdRole, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + tx, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) require.NoError(t, err, "no error expected while seeding database role") - err = createdRole.Commit() + err = tx.Commit() require.NoError(t, err, "no error expected while committing role creation") @@ -47,8 +47,8 @@ func TestGetRoleByID(t *testing.T) { assert.Equal(t, roleName, role.Name) assert.Equal(t, resourceID, role.ResourceID) assert.Equal(t, actorID, role.CreatorID) - assert.Equal(t, createdRole.CreatedAt, role.CreatedAt) - assert.Equal(t, createdRole.UpdatedAt, role.UpdatedAt) + assert.Equal(t, tx.Record.CreatedAt, role.CreatedAt) + assert.Equal(t, tx.Record.UpdatedAt, role.UpdatedAt) } func TestGetResourceRoleByName(t *testing.T) { @@ -68,11 +68,11 @@ func TestGetResourceRoleByName(t *testing.T) { assert.ErrorIs(t, err, database.ErrNoRoleFound) require.Nil(t, role, "role expected to be returned") - createdRole, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + roleTx, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) require.NoError(t, err, "no error expected while seeding database role") - err = createdRole.Commit() + err = roleTx.Commit() require.NoError(t, err, "no error expected while committing role creation") @@ -86,8 +86,8 @@ func TestGetResourceRoleByName(t *testing.T) { assert.Equal(t, roleName, role.Name) assert.Equal(t, resourceID, role.ResourceID) assert.Equal(t, actorID, role.CreatorID) - assert.Equal(t, createdRole.CreatedAt, role.CreatedAt) - assert.Equal(t, createdRole.UpdatedAt, role.UpdatedAt) + assert.Equal(t, roleTx.Record.CreatedAt, role.CreatedAt) + assert.Equal(t, roleTx.Record.UpdatedAt, role.UpdatedAt) } func TestListResourceRoles(t *testing.T) { @@ -112,11 +112,11 @@ func TestListResourceRoles(t *testing.T) { } for roleName, roleID := range groups { - role, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + roleTx, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) - require.NoError(t, err, "no error expected creating role", roleName) + require.NoError(t, err, "no error expected creating role transaction", roleName) - err = role.Commit() + err = roleTx.Commit() require.NoError(t, err, "no error expected while committing role", roleName) } @@ -139,7 +139,7 @@ func TestListResourceRoles(t *testing.T) { } } -func TestCreateRole(t *testing.T) { +func TestCreateRoleTransaction(t *testing.T) { db, dbClose := testdb.NewTestDatabase(t) defer dbClose() @@ -150,28 +150,28 @@ func TestCreateRole(t *testing.T) { roleName := "admins" resourceID := gidx.PrefixedID("testten-jkl789") - role, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + roleTx, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) require.NoError(t, err, "no error expected while creating role") - err = role.Commit() + err = roleTx.Commit() require.NoError(t, err, "no error expected while committing role creation") - assert.Equal(t, roleID, role.ID) - assert.Equal(t, roleName, role.Name) - assert.Equal(t, resourceID, role.ResourceID) - assert.Equal(t, actorID, role.CreatorID) - assert.False(t, role.CreatedAt.IsZero()) - assert.False(t, role.UpdatedAt.IsZero()) + assert.Equal(t, roleID, roleTx.Record.ID) + assert.Equal(t, roleName, roleTx.Record.Name) + assert.Equal(t, resourceID, roleTx.Record.ResourceID) + assert.Equal(t, actorID, roleTx.Record.CreatorID) + assert.False(t, roleTx.Record.CreatedAt.IsZero()) + assert.False(t, roleTx.Record.UpdatedAt.IsZero()) - dupeRole, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + dupeRole, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) assert.Error(t, err, "expected error for duplicate index") assert.ErrorIs(t, err, database.ErrRoleAlreadyExists, "expected error to be for role already exists") require.Nil(t, dupeRole, "expected role to be nil") - takenNameRole, err := db.CreateRole(ctx, actorID, roleID2, roleName, resourceID) + takenNameRole, err := db.CreateRoleTransaction(ctx, actorID, roleID2, roleName, resourceID) assert.Error(t, err, "expected error for already taken name") assert.ErrorIs(t, err, database.ErrRoleNameTaken, "expected error to be for already taken name") @@ -191,50 +191,51 @@ func TestUpdateRole(t *testing.T) { roleName2 := "temps" resourceID := gidx.PrefixedID("testten-jkl789") - createdRole, err := db.CreateRole(ctx, createActorID, roleID1, roleName, resourceID) + createdRoleTx, err := db.CreateRoleTransaction(ctx, createActorID, roleID1, roleName, resourceID) require.NoError(t, err, "no error expected while seeding database role") - err = createdRole.Commit() + err = createdRoleTx.Commit() require.NoError(t, err, "no error expected while committing role creation") - createdRole2, err := db.CreateRole(ctx, createActorID, roleID2, roleName2, resourceID) + createdRole2Tx, err := db.CreateRoleTransaction(ctx, createActorID, roleID2, roleName2, resourceID) require.NoError(t, err, "no error expected while seeding database role 2") - err = createdRole2.Commit() + err = createdRole2Tx.Commit() require.NoError(t, err, "no error expected while committing role 2 creation") updateActorID := gidx.PrefixedID("idntusr-abc456") t.Run("update error", func(t *testing.T) { - role, err := db.UpdateRole(ctx, updateActorID, roleID2, roleName, resourceID) + roleTx, err := db.UpdateRoleTransaction(ctx, updateActorID, roleID2, roleName, resourceID) assert.Error(t, err, "expected error updating role name to an already taken role name") assert.ErrorIs(t, err, database.ErrRoleNameTaken, "expected error to be role name taken error") - assert.Nil(t, role, "expected role to be nil") + assert.Nil(t, roleTx, "expected role to be nil") }) updateRoleName := "new-admins" updateResourceID := gidx.PrefixedID("testten-mno101") t.Run("existing role", func(t *testing.T) { - role, err := db.UpdateRole(ctx, updateActorID, roleID1, updateRoleName, updateResourceID) + updateTx, err := db.UpdateRoleTransaction(ctx, updateActorID, roleID1, updateRoleName, updateResourceID) require.NoError(t, err, "no error expected while updating role") - require.NotNil(t, role, "role expected to be returned") + require.NotNil(t, updateTx, "transaction expected to be returned") + require.NotNil(t, updateTx.Record, "role expected to be returned") - err = role.Commit() + err = updateTx.Commit() require.NoError(t, err, "no error expected while committing role update") - assert.Equal(t, roleID1, role.ID) - assert.Equal(t, updateRoleName, role.Name) - assert.Equal(t, updateResourceID, role.ResourceID) - assert.Equal(t, createActorID, role.CreatorID) - assert.Equal(t, createdRole.CreatedAt, role.CreatedAt) - assert.NotEqual(t, createdRole.UpdatedAt, role.UpdatedAt) + assert.Equal(t, roleID1, updateTx.Record.ID) + assert.Equal(t, updateRoleName, updateTx.Record.Name) + assert.Equal(t, updateResourceID, updateTx.Record.ResourceID) + assert.Equal(t, createActorID, updateTx.Record.CreatorID) + assert.Equal(t, createdRoleTx.Record.CreatedAt, updateTx.Record.CreatedAt) + assert.NotEqual(t, createdRoleTx.Record.UpdatedAt, updateTx.Record.UpdatedAt) }) t.Run("new role", func(t *testing.T) { @@ -242,22 +243,23 @@ func TestUpdateRole(t *testing.T) { newRoleName := "users" newResourceID := gidx.PrefixedID("testten-lmn159") - role, err := db.UpdateRole(ctx, updateActorID, newRoleID, newRoleName, newResourceID) + updateTx, err := db.UpdateRoleTransaction(ctx, updateActorID, newRoleID, newRoleName, newResourceID) require.NoError(t, err, "no error expected while updating role") - require.NotNil(t, role, "role expected to be returned") + require.NotNil(t, updateTx, "transaction expected to be returned") + require.NotNil(t, updateTx.Record, "role expected to be returned") - err = role.Commit() + err = updateTx.Commit() require.NoError(t, err, "no error expected while committing new role from update") - assert.Equal(t, newRoleID, role.ID) - assert.Equal(t, newRoleName, role.Name) - assert.Equal(t, newResourceID, role.ResourceID) - assert.Equal(t, updateActorID, role.CreatorID) - assert.False(t, createdRole.CreatedAt.IsZero()) - assert.False(t, createdRole.UpdatedAt.IsZero()) + assert.Equal(t, newRoleID, updateTx.Record.ID) + assert.Equal(t, newRoleName, updateTx.Record.Name) + assert.Equal(t, newResourceID, updateTx.Record.ResourceID) + assert.Equal(t, updateActorID, updateTx.Record.CreatorID) + assert.False(t, createdRoleTx.Record.CreatedAt.IsZero()) + assert.False(t, createdRoleTx.Record.UpdatedAt.IsZero()) }) } @@ -271,28 +273,28 @@ func TestDeleteRole(t *testing.T) { roleName := "admins" resourceID := gidx.PrefixedID("testten-jkl789") - _, err := db.DeleteRole(ctx, roleID) + _, err := db.DeleteRoleTransaction(ctx, roleID) require.Error(t, err, "error expected while deleting role which doesn't exist") require.ErrorIs(t, err, database.ErrNoRoleFound, "expected no role found error for missing role") - role, err := db.CreateRole(ctx, actorID, roleID, roleName, resourceID) + createTx, err := db.CreateRoleTransaction(ctx, actorID, roleID, roleName, resourceID) require.NoError(t, err, "no error expected while seeding database role") - err = role.Commit() + err = createTx.Commit() require.NoError(t, err, "no error expected while committing role creation") - role, err = db.DeleteRole(ctx, roleID) + deleteTx, err := db.DeleteRoleTransaction(ctx, roleID) require.NoError(t, err, "no error expected while deleting role") - err = role.Commit() + err = deleteTx.Commit() require.NoError(t, err, "no error expected while committing role deletion") - role, err = db.GetRoleByID(ctx, roleID) + role, err := db.GetRoleByID(ctx, roleID) require.Error(t, err, "expected error retrieving role") assert.ErrorIs(t, err, database.ErrNoRoleFound, "expected no rows error") diff --git a/internal/database/transactions.go b/internal/database/transactions.go new file mode 100644 index 000000000..a3ec2ac7c --- /dev/null +++ b/internal/database/transactions.go @@ -0,0 +1,43 @@ +package database + +import ( + "database/sql" + "errors" + + "go.uber.org/zap" +) + +// Transaction represents an in flight change being made to the database that must be committed or rolled back. +type Transaction[T any] struct { + logger *zap.SugaredLogger + tx *sql.Tx + + Record T +} + +// Commit completes the transaction and writes the changes to the database. +func (t *Transaction[T]) Commit() error { + return t.tx.Commit() +} + +// Rollback reverts the transaction and discards the changes from the database. +// +// To simplify rollbacks, logging has automatically been setup to log any errors produced if a rollback fails. +func (t *Transaction[T]) Rollback() error { + err := t.tx.Rollback() + if err != nil && !errors.Is(err, sql.ErrTxDone) { + t.logger.Errorw("failed to rollback transaction", zap.Error(err)) + } + + return err +} + +// newTransaction creates a new Transaction with the required fields. +func newTransaction[T any](logger *zap.SugaredLogger, tx *sql.Tx, record T) *Transaction[T] { + return &Transaction[T]{ + logger: logger, + tx: tx, + + Record: record, + } +} diff --git a/internal/query/relations.go b/internal/query/relations.go index ab6395bc7..1a5b4421a 100644 --- a/internal/query/relations.go +++ b/internal/query/relations.go @@ -277,12 +277,12 @@ func (e *engine) CreateRole(ctx context.Context, actor, res types.Resource, role role := newRole(roleName, actions) roleRels := e.roleRelationships(role, res) - dbRole, err := e.db.CreateRole(ctx, actor.ID, role.ID, roleName, res.ID) + dbTx, err := e.db.CreateRoleTransaction(ctx, actor.ID, role.ID, roleName, res.ID) if err != nil { return types.Role{}, err } - defer dbRole.Rollback() //nolint:errcheck // error is logged in function + defer dbTx.Rollback() //nolint:errcheck // error is logged in function request := &pb.WriteRelationshipsRequest{Updates: roleRels} @@ -293,17 +293,17 @@ func (e *engine) CreateRole(ctx context.Context, actor, res types.Resource, role return types.Role{}, err } - if err = dbRole.Commit(); err != nil { + if err = dbTx.Commit(); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return types.Role{}, err } - role.Creator = dbRole.CreatorID - role.ResourceID = dbRole.ResourceID - role.CreatedAt = dbRole.CreatedAt - role.UpdatedAt = dbRole.UpdatedAt + role.Creator = dbTx.Record.CreatorID + role.ResourceID = dbTx.Record.ResourceID + role.CreatedAt = dbTx.Record.CreatedAt + role.UpdatedAt = dbTx.Record.UpdatedAt return role, nil } @@ -383,18 +383,18 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou } var ( - dbRole *database.Role - dbErr error + dbTx database.TxRole + dbErr error ) // If new name has changed, commit change to permissions database. if newName != "" && role.Name != newName { - dbRole, dbErr = e.db.UpdateRole(ctx, actor.ID, role.ID, newName, resourceID) + dbTx, dbErr = e.db.UpdateRoleTransaction(ctx, actor.ID, role.ID, newName, resourceID) if dbErr != nil { return types.Role{}, dbErr } - defer dbRole.Rollback() //nolint:errcheck // error is logged in function + defer dbTx.Rollback() //nolint:errcheck // error is logged in function } // If a change in actions, apply changes to spicedb. @@ -413,20 +413,20 @@ func (e *engine) UpdateRole(ctx context.Context, actor, roleResource types.Resou role.Actions = newActions } - // Only commit if dbRole is defined meaning the name was also updated. - if dbRole != nil { - if err = dbRole.Commit(); err != nil { + // Only commit if dbTx is defined meaning the name was also updated. + if dbTx != nil { + if err = dbTx.Commit(); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error()) return types.Role{}, err } - role.Name = dbRole.Name - role.Creator = dbRole.CreatorID - role.ResourceID = dbRole.ResourceID - role.CreatedAt = dbRole.CreatedAt - role.UpdatedAt = dbRole.UpdatedAt + role.Name = dbTx.Record.Name + role.Creator = dbTx.Record.CreatorID + role.ResourceID = dbTx.Record.ResourceID + role.CreatedAt = dbTx.Record.CreatedAt + role.UpdatedAt = dbTx.Record.UpdatedAt } return role, nil @@ -1048,7 +1048,7 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er } } - dbRole, err := e.db.DeleteRole(ctx, roleResource.ID) + dbTx, err := e.db.DeleteRoleTransaction(ctx, roleResource.ID) if err != nil { // If the role doesn't exist, simply ignore. if !errors.Is(err, database.ErrNoRoleFound) { @@ -1056,7 +1056,7 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er } } else { // Setup rollback in case an error occurs before we commit. - defer dbRole.Rollback() //nolint:errcheck // error is logged in function + defer dbTx.Rollback() //nolint:errcheck // error is logged in function } for _, filter := range filters { @@ -1070,9 +1070,9 @@ func (e *engine) DeleteRole(ctx context.Context, roleResource types.Resource) er } } - // If the role was not found, dbRole will be nil. - if dbRole != nil { - if err = dbRole.Commit(); err != nil { + // If the role was not found, dbTx will be nil. + if dbTx != nil { + if err = dbTx.Commit(); err != nil { span.RecordError(err) span.SetStatus(codes.Error, err.Error())