diff --git a/service/pkg/db/db.go b/service/pkg/db/db.go index 038ee0543d..3db39d11e0 100644 --- a/service/pkg/db/db.go +++ b/service/pkg/db/db.go @@ -54,6 +54,7 @@ func (t Table) Field(field string) string { // We can rename this but wanted to get mocks working. type PgxIface interface { Acquire(ctx context.Context) (*pgxpool.Conn, error) + Begin(ctx context.Context) (pgx.Tx, error) Exec(context.Context, string, ...any) (pgconn.CommandTag, error) QueryRow(context.Context, string, ...any) pgx.Row Query(context.Context, string, ...any) (pgx.Rows, error) diff --git a/service/policy/db/policy.go b/service/policy/db/policy.go index cb2304500c..5889f9398f 100644 --- a/service/policy/db/policy.go +++ b/service/policy/db/policy.go @@ -1,6 +1,9 @@ package db import ( + "context" + + "github.com/jackc/pgx/v5" "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/pkg/db" @@ -19,6 +22,18 @@ type PolicyDBClient struct { *Queries } +func (c *PolicyDBClient) BeginTx(ctx context.Context) (pgx.Tx, error) { + tx, err := c.Client.Pgx.Begin(ctx) + if err != nil { + return nil, err + } + return tx, nil +} + +func (c *PolicyDBClient) WithTx(tx pgx.Tx) *PolicyDBClient { + return &PolicyDBClient{c.Client, c.logger, c.Queries.WithTx(tx)} +} + func NewClient(c *db.Client, logger *logger.Logger) PolicyDBClient { return PolicyDBClient{c, logger, New(c.Pgx)} } diff --git a/service/policy/db/utils.go b/service/policy/db/utils.go index 30864bb909..657d71eab4 100644 --- a/service/policy/db/utils.go +++ b/service/policy/db/utils.go @@ -1,9 +1,11 @@ package db import ( + "context" "fmt" "github.com/google/uuid" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" "github.com/opentdf/platform/protocol/go/common" "github.com/opentdf/platform/protocol/go/policy" @@ -59,3 +61,9 @@ func pgtypeBool(b bool) pgtype.Bool { Valid: true, } } + +// Helper function for swallowing the error in a Pgx Transaction rollback per the documentation +func TxRollback(ctx context.Context, tx pgx.Tx) { + //nolint:errcheck // noop https://pkg.go.dev/github.com/jackc/pgx#hdr-Transactions + tx.Rollback(ctx) +} diff --git a/service/policy/namespaces/namespaces.go b/service/policy/namespaces/namespaces.go index 5598f5880c..da68a7f6a0 100644 --- a/service/policy/namespaces/namespaces.go +++ b/service/policy/namespaces/namespaces.go @@ -93,12 +93,24 @@ func (ns NamespacesService) CreateNamespace(ctx context.Context, req *namespaces } rsp := &namespaces.CreateNamespaceResponse{} - n, err := ns.dbClient.CreateNamespace(ctx, req) + tx, err := ns.dbClient.BeginTx(ctx) + if err != nil { + ns.logger.Audit.PolicyCRUDFailure(ctx, auditParams) + return nil, db.StatusifyError(err, "begin txn failed", slog.String("name", req.GetName())) + } + defer policydb.TxRollback(ctx, tx) + + n, err := ns.dbClient.WithTx(tx).CreateNamespace(ctx, req) if err != nil { ns.logger.Audit.PolicyCRUDFailure(ctx, auditParams) return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("name", req.GetName())) } + if err = tx.Commit(ctx); err != nil { + ns.logger.Audit.PolicyCRUDFailure(ctx, auditParams) + return nil, db.StatusifyError(err, "commit txn failed", slog.String("name", req.GetName())) + } + auditParams.ObjectID = n.GetId() auditParams.Original = n ns.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)