Skip to content

Commit 90edbde

Browse files
authored
feat(policy): 1500 Attribute create with Values (one RPC Call) should employ a db transaction (#1778)
### Proposed Changes fix #1500 - adds `PolicyDbClient.RunInTx()` method - updates attribute create to use the method ### Checklist - [ ] I have added or updated unit tests - [ ] I have added or updated integration tests (if appropriate) - [ ] I have added or updated documentation ### Testing Instructions
1 parent 33f1afd commit 90edbde

File tree

5 files changed

+178
-9
lines changed

5 files changed

+178
-9
lines changed

service/integration/policy_test.go

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package integration
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"log/slog"
7+
"testing"
8+
9+
"github.com/opentdf/platform/service/internal/fixtures"
10+
"github.com/opentdf/platform/service/policy/db"
11+
"github.com/stretchr/testify/suite"
12+
)
13+
14+
type PolicyDBClientSuite struct {
15+
suite.Suite
16+
f fixtures.Fixtures
17+
db fixtures.DBInterface
18+
ctx context.Context //nolint:containedctx // context is used in the test suite
19+
}
20+
21+
func (s *PolicyDBClientSuite) SetupSuite() {
22+
s.ctx = context.Background()
23+
c := *Config
24+
c.DB.Schema = "text_opentdf_policy_db_client"
25+
s.db = fixtures.NewDBInterface(c)
26+
s.f = fixtures.NewFixture(s.db)
27+
s.f.Provision()
28+
}
29+
30+
func (s *PolicyDBClientSuite) TearDownSuite() {
31+
slog.Info("tearing down db.PolicyDbClient test suite")
32+
s.f.TearDown()
33+
}
34+
35+
func (s *PolicyDBClientSuite) Test_RunInTx_CommitsOnSuccess() {
36+
var (
37+
nsName = "success.com"
38+
attrName = fmt.Sprintf("http://%s/attr/attr_one", nsName)
39+
attrValue = fmt.Sprintf("http://%s/attr/%s/value/attr_one_value", nsName, attrName)
40+
41+
nsID string
42+
attrID string
43+
valID string
44+
err error
45+
)
46+
47+
txErr := s.db.PolicyClient.RunInTx(s.ctx, func(txClient *db.PolicyDBClient) error {
48+
nsID, err = txClient.Queries.CreateNamespace(s.ctx, db.CreateNamespaceParams{
49+
Name: nsName,
50+
})
51+
s.Require().NoError(err)
52+
s.Require().NotNil(nsID)
53+
54+
attrID, err = txClient.Queries.CreateAttribute(s.ctx, db.CreateAttributeParams{
55+
NamespaceID: nsID,
56+
Name: attrName,
57+
Rule: db.AttributeDefinitionRuleALLOF,
58+
})
59+
s.Require().NoError(err)
60+
s.Require().NotNil(attrID)
61+
62+
valID, err = txClient.Queries.CreateAttributeValue(s.ctx, db.CreateAttributeValueParams{
63+
AttributeDefinitionID: attrID,
64+
Value: attrValue,
65+
})
66+
s.Require().NoError(err)
67+
s.Require().NotNil(valID)
68+
69+
return nil
70+
})
71+
s.Require().NoError(txErr)
72+
73+
ns, err := s.db.PolicyClient.GetNamespace(s.ctx, nsID)
74+
s.Require().NoError(err)
75+
s.Equal(nsName, ns.GetName())
76+
77+
attr, err := s.db.PolicyClient.GetAttribute(s.ctx, attrID)
78+
s.Require().NoError(err)
79+
s.Equal(attrName, attr.GetName())
80+
81+
attrVal, err := s.db.PolicyClient.GetAttributeValue(s.ctx, valID)
82+
s.Require().NoError(err)
83+
s.Equal(attrValue, attrVal.GetValue())
84+
}
85+
86+
func (s *PolicyDBClientSuite) Test_RunInTx_RollsBackOnFailure() {
87+
var (
88+
nsName = "failure.com"
89+
attrName = fmt.Sprintf("http://%s/attr/attr_one", nsName)
90+
91+
nsID string
92+
attrID string
93+
err error
94+
)
95+
96+
txErr := s.db.PolicyClient.RunInTx(s.ctx, func(txClient *db.PolicyDBClient) error {
97+
nsID, err = txClient.Queries.CreateNamespace(s.ctx, db.CreateNamespaceParams{
98+
Name: nsName,
99+
})
100+
s.Require().NoError(err)
101+
s.Require().NotNil(nsID)
102+
103+
attrID, err = txClient.Queries.CreateAttribute(s.ctx, db.CreateAttributeParams{
104+
NamespaceID: "invalid_ns_id",
105+
Name: attrName,
106+
Rule: db.AttributeDefinitionRuleALLOF,
107+
})
108+
s.Require().Error(err)
109+
s.Require().Zero(attrID)
110+
return err
111+
})
112+
s.Require().Error(txErr)
113+
114+
ns, err := s.db.PolicyClient.GetNamespace(s.ctx, nsID)
115+
s.Require().Error(err)
116+
s.Nil(ns)
117+
118+
attr, err := s.db.PolicyClient.GetAttribute(s.ctx, attrID)
119+
s.Require().Error(err)
120+
s.Nil(attr)
121+
}
122+
123+
func TestPolicySuite(t *testing.T) {
124+
if testing.Short() {
125+
t.Skip("skipping policy integration tests")
126+
}
127+
suite.Run(t, new(PolicyDBClientSuite))
128+
}

service/pkg/db/db.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ func (t Table) Field(field string) string {
5454
// We can rename this but wanted to get mocks working.
5555
type PgxIface interface {
5656
Acquire(ctx context.Context) (*pgxpool.Conn, error)
57+
Begin(ctx context.Context) (pgx.Tx, error)
5758
Exec(context.Context, string, ...any) (pgconn.CommandTag, error)
5859
QueryRow(context.Context, string, ...any) pgx.Row
5960
Query(context.Context, string, ...any) (pgx.Rows, error)

service/pkg/db/errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ var (
2222
ErrUUIDInvalid = errors.New("ErrUUIDInvalid: value not a valid UUID")
2323
ErrMissingValue = errors.New("ErrMissingValue: value must be included")
2424
ErrListLimitTooLarge = errors.New("ErrListLimitTooLarge: requested limit greater than configured maximum")
25+
ErrTxBeginFailed = errors.New("ErrTxBeginFailed: failed to begin DB transaction")
26+
ErrTxRollbackFailed = errors.New("ErrTxRollbackFailed: failed to rollback DB transaction")
27+
ErrTxCommitFailed = errors.New("ErrTxCommitFailed: failed to commit DB transaction")
2528
)
2629

2730
// Get helpful error message for PostgreSQL violation

service/policy/attributes/attributes.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,26 @@ func (s AttributesService) CreateAttribute(ctx context.Context,
5454
ActionType: audit.ActionTypeCreate,
5555
}
5656

57-
item, err := s.dbClient.CreateAttribute(ctx, req.Msg)
57+
err := s.dbClient.RunInTx(ctx, func(txClient *policydb.PolicyDBClient) error {
58+
item, err := txClient.CreateAttribute(ctx, req.Msg)
59+
if err != nil {
60+
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
61+
return err
62+
}
63+
64+
s.logger.Debug("created new attribute definition", slog.String("name", req.Msg.GetName()))
65+
66+
auditParams.ObjectID = item.GetId()
67+
auditParams.Original = item
68+
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)
69+
70+
rsp.Attribute = item
71+
return nil
72+
})
5873
if err != nil {
59-
s.logger.Audit.PolicyCRUDFailure(ctx, auditParams)
6074
return nil, db.StatusifyError(err, db.ErrTextCreationFailed, slog.String("attribute", req.Msg.String()))
6175
}
6276

63-
s.logger.Debug("created new attribute definition", slog.String("name", req.Msg.GetName()))
64-
65-
auditParams.ObjectID = item.GetId()
66-
auditParams.Original = item
67-
s.logger.Audit.PolicyCRUDSuccess(ctx, auditParams)
68-
69-
rsp.Attribute = item
7077
return connect.NewResponse(rsp), nil
7178
}
7279

service/policy/db/policy.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package db
22

33
import (
4+
"context"
5+
"fmt"
6+
47
"github.com/opentdf/platform/protocol/go/common"
58
"github.com/opentdf/platform/service/logger"
69
"github.com/opentdf/platform/service/pkg/db"
@@ -31,6 +34,33 @@ func NewClient(c *db.Client, logger *logger.Logger, configuredListLimitMax, conf
3134
return PolicyDBClient{c, logger, New(c.Pgx), ListConfig{limitDefault: configuredListLimitDefault, limitMax: configuredListLimitMax}}
3235
}
3336

37+
func (c *PolicyDBClient) RunInTx(ctx context.Context, query func(txClient *PolicyDBClient) error) error {
38+
tx, err := c.Client.Pgx.Begin(ctx)
39+
if err != nil {
40+
return fmt.Errorf("%w: %w", db.ErrTxBeginFailed, err)
41+
}
42+
43+
txClient := &PolicyDBClient{c.Client, c.logger, c.Queries.WithTx(tx), c.listCfg}
44+
45+
err = query(txClient)
46+
if err != nil {
47+
c.logger.WarnContext(ctx, "error during DB transaction, rolling back")
48+
49+
if rollbackErr := tx.Rollback(ctx); rollbackErr != nil {
50+
// this should never happen, but if it does, we want to know about it
51+
return fmt.Errorf("%w, transaction [%w]: %w", db.ErrTxRollbackFailed, err, rollbackErr)
52+
}
53+
54+
return err
55+
}
56+
57+
if err = tx.Commit(ctx); err != nil {
58+
return fmt.Errorf("%w: %w", db.ErrTxCommitFailed, err)
59+
}
60+
61+
return nil
62+
}
63+
3464
func getDBStateTypeTransformedEnum(state common.ActiveStateEnum) transformedState {
3565
switch state.String() {
3666
case common.ActiveStateEnum_ACTIVE_STATE_ENUM_ACTIVE.String():

0 commit comments

Comments
 (0)