diff --git a/service/integration/obligation_triggers_test.go b/service/integration/obligation_triggers_test.go index 553379023..2141ba2d5 100644 --- a/service/integration/obligation_triggers_test.go +++ b/service/integration/obligation_triggers_test.go @@ -2,6 +2,7 @@ package integration import ( "context" + "fmt" "log/slog" "testing" @@ -29,16 +30,28 @@ const ( type ObligationTriggersSuite struct { suite.Suite - ctx context.Context //nolint:containedctx // context is used in the test suite - db fixtures.DBInterface - f fixtures.Fixtures - namespace *policy.Namespace - attribute *policy.Attribute - attributeValue *policy.Value - action *policy.Action - obligation *policy.Obligation - obligationValue *policy.ObligationValue - triggerIDsToClean []string + ctx context.Context //nolint:containedctx // context is used in the test suite + db fixtures.DBInterface + f fixtures.Fixtures + namespace *policy.Namespace + attribute *policy.Attribute + attributeValue *policy.Value + action *policy.Action + obligation *policy.Obligation + obligationValue *policy.ObligationValue + triggerIDsToClean []string + obligationValueIDsToClean []string +} + +type DifferentNamespaceEntities struct { + Namespace *policy.Namespace + Obligation *policy.Obligation + ObligationValue *policy.ObligationValue + Attribute *policy.Attribute + AttributeValue *policy.Value + Trigger *policy.ObligationTrigger + CleanupNamespace func() + CleanupTrigger func() } func (s *ObligationTriggersSuite) SetupSuite() { @@ -119,7 +132,14 @@ func (s *ObligationTriggersSuite) TearDownTest() { }) s.Require().NoError(err) } + for _, obligationValueID := range s.obligationValueIDsToClean { + _, err := s.db.PolicyClient.DeleteObligationValue(s.ctx, &obligations.DeleteObligationValueRequest{ + Id: obligationValueID, + }) + s.Require().NoError(err) + } s.triggerIDsToClean = nil + s.obligationValueIDsToClean = nil } func TestObligationTriggersSuite(t *testing.T) { @@ -146,7 +166,7 @@ func (s *ObligationTriggersSuite) Test_CreateObligationTrigger_WithIDs_Success() }) s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) s.Require().NoError(err) - s.validateTrigger(trigger, true) + s.validateTriggerWithDefaults(trigger, true) s.Require().Equal("test", trigger.GetMetadata().GetLabels()["source"]) } @@ -158,7 +178,7 @@ func (s *ObligationTriggersSuite) Test_CreateObligationTrigger_NoCtx_Success() { }) s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) s.Require().NoError(err) - s.validateTrigger(trigger, false) + s.validateTriggerWithDefaults(trigger, false) } func (s *ObligationTriggersSuite) Test_CreateObligationTrigger_WithNameFQN_Success() { @@ -177,7 +197,7 @@ func (s *ObligationTriggersSuite) Test_CreateObligationTrigger_WithNameFQN_Succe }) s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) s.Require().NoError(err) - s.validateTrigger(trigger, true) + s.validateTriggerWithDefaults(trigger, true) s.Require().Equal("test", trigger.GetMetadata().GetLabels()["source"]) } @@ -293,6 +313,202 @@ func (s *ObligationTriggersSuite) Test_DeleteObligationTrigger_NotFound_Fails() s.Require().ErrorIs(err, db.ErrNotFound) } +// ListObligationTriggers tests +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_NoTriggersNoFilter_Success() { + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{}) + s.Require().NoError(err) + s.Require().Empty(triggers) + s.Require().NotNil(pageResult) + s.validatePageResponses(pageResult, 0, 0, 0) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_NoTriggersWithNamespaceId_Success() { + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceId: s.namespace.GetId(), + }) + s.Require().NoError(err) + s.Require().Empty(triggers) + s.validatePageResponses(pageResult, 0, 0, 0) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_NoTriggersWithNamespaceFqn_Success() { + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceFqn: s.namespace.GetFqn(), + }) + s.Require().NoError(err) + s.Require().Empty(triggers) + s.validatePageResponses(pageResult, 0, 0, 0) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_MultipleTriggersNoFilter_MultipleNamespaces_Success() { + createdTriggersMap := s.createMultipleUniqueTriggers(2) + s.appendObligationValuesToClean(createdTriggersMap) + differentNS := s.createDifferentNamespaceWithTrigger("different-namespace-id-test") + defer differentNS.CleanupNamespace() + defer differentNS.CleanupTrigger() + createdTriggersMap[differentNS.Trigger.GetId()] = differentNS.Trigger + + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{}) + s.Require().NoError(err) + s.Require().Len(triggers, 3) + s.validatePageResponses(pageResult, 3, 0, 0) + + // Verify all triggers are returned + foundTriggers := make(map[string]bool) + for _, t := range triggers { + createdTrigger, ok := createdTriggersMap[t.GetId()] + s.Require().True(ok) + foundTriggers[t.GetId()] = true + s.validateTrigger(t, createdTrigger.GetObligationValue(), createdTrigger.GetAttributeValue(), createdTrigger.GetAction(), true) + } + // Validate all triggers are found + for id := range createdTriggersMap { + s.Require().True(foundTriggers[id]) + } +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_MultipleTriggersWithNamespaceId_MultipleNamespaces_Success() { + createdTriggersMap := s.createMultipleUniqueTriggers(2) + s.appendObligationValuesToClean(createdTriggersMap) + differentNS := s.createDifferentNamespaceWithTrigger("different-namespace-id-test") + defer differentNS.CleanupNamespace() + defer differentNS.CleanupTrigger() + + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceId: s.namespace.GetId(), + }) + s.Require().NoError(err) + s.Require().Len(triggers, 2) + s.validatePageResponses(pageResult, 2, 0, 0) + + foundTriggers := make(map[string]bool) + for _, t := range triggers { + s.Require().Equal(s.namespace.GetId(), t.GetObligationValue().GetObligation().GetNamespace().GetId()) + createdTrigger, ok := createdTriggersMap[t.GetId()] + s.Require().True(ok) + s.validateTrigger(t, createdTrigger.GetObligationValue(), createdTrigger.GetAttributeValue(), createdTrigger.GetAction(), true) + foundTriggers[t.GetId()] = true + } + for id := range createdTriggersMap { + s.Require().True(foundTriggers[id]) + } +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_MultipleTriggersWithNamespaceFqn_Success() { + createdTriggersMap := s.createMultipleUniqueTriggers(2) + s.appendObligationValuesToClean(createdTriggersMap) + differentNS := s.createDifferentNamespaceWithTrigger("different-namespace-id-test") + defer differentNS.CleanupNamespace() + defer differentNS.CleanupTrigger() + + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceFqn: s.namespace.GetFqn(), + }) + s.Require().NoError(err) + s.Require().Len(triggers, 2) + s.validatePageResponses(pageResult, 2, 0, 0) + + foundTriggers := make(map[string]bool) + for _, t := range triggers { + s.Require().Equal(s.namespace.GetFqn(), t.GetObligationValue().GetObligation().GetNamespace().GetFqn()) + createdTrigger, ok := createdTriggersMap[t.GetId()] + s.Require().True(ok) + s.validateTrigger(t, createdTrigger.GetObligationValue(), createdTrigger.GetAttributeValue(), createdTrigger.GetAction(), true) + foundTriggers[t.GetId()] = true + } + for id := range createdTriggersMap { + s.Require().True(foundTriggers[id]) + } +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_WithPagination_Success() { + createdTriggersMap := s.createMultipleUniqueTriggers(5) + s.appendObligationValuesToClean(createdTriggersMap) + var currentOffset int32 + foundTriggers := make(map[string]bool) + var total int32 = 5 + var limit int32 = 2 + for i := 0; i < 3; i++ { + nextOffset := currentOffset + 2 + expectedTriggersCount := 2 + if i == 2 { + nextOffset = 0 + expectedTriggersCount = 1 + } + + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + Pagination: &policy.PageRequest{ + Limit: limit, + Offset: currentOffset, + }, + }) + s.Require().NoError(err) + + for _, t := range triggers { + s.validateTrigger(t, createdTriggersMap[t.GetId()].GetObligationValue(), createdTriggersMap[t.GetId()].GetAttributeValue(), createdTriggersMap[t.GetId()].GetAction(), true) + foundTriggers[t.GetId()] = true + } + s.Require().Len(triggers, expectedTriggersCount) + s.validatePageResponses(pageResult, total, currentOffset, nextOffset) + currentOffset = pageResult.GetNextOffset() + } + s.Require().Len(foundTriggers, 5) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_WithNamespaceAndPagination_Success() { + trigger := s.createGenericTrigger() + s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) + differentNS := s.createDifferentNamespaceWithTrigger("different-namespace-for-list-test") + defer differentNS.CleanupNamespace() + defer differentNS.CleanupTrigger() + + triggers, pageRes, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceId: s.namespace.GetId(), + Pagination: &policy.PageRequest{ + Limit: 1, + }, + }) + s.Require().NoError(err) + s.Require().Len(triggers, 1) + s.validatePageResponses(pageRes, 1, 0, 0) + s.Require().Equal(s.namespace.GetId(), triggers[0].GetObligationValue().GetObligation().GetNamespace().GetId()) + s.validateTriggerWithDefaults(triggers[0], true) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_LimitToLarge() { + triggers, pageRes, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{ + NamespaceId: s.namespace.GetId(), + Pagination: &policy.PageRequest{ + Limit: s.db.LimitMax + 1, + }, + }) + s.Require().ErrorIs(err, db.ErrListLimitTooLarge) + s.Require().Nil(triggers) + s.Require().Nil(pageRes) +} + +func (s *ObligationTriggersSuite) Test_ListObligationTriggers_NoContext_Success() { + // Create a trigger without context + trigger, err := s.db.PolicyClient.CreateObligationTrigger(s.ctx, &obligations.AddObligationTriggerRequest{ + ObligationValue: &common.IdFqnIdentifier{Id: s.obligationValue.GetId()}, + AttributeValue: &common.IdFqnIdentifier{Id: s.attributeValue.GetId()}, + Action: &common.IdNameIdentifier{Id: s.action.GetId()}, + }) + s.Require().NoError(err) + s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) + + // List triggers + triggers, pageResult, err := s.db.PolicyClient.ListObligationTriggers(s.ctx, &obligations.ListObligationTriggersRequest{}) + s.Require().NoError(err) + s.Require().Len(triggers, 1) + s.validatePageResponses(pageResult, 1, 0, 0) + + // Verify the listed trigger has no context + listedTrigger := triggers[0] + s.Require().Equal(trigger.GetId(), listedTrigger.GetId()) + s.validateTriggerWithDefaults(listedTrigger, false) +} + func (s *ObligationTriggersSuite) createGenericTrigger() *policy.ObligationTrigger { trigger, err := s.db.PolicyClient.CreateObligationTrigger(s.ctx, &obligations.AddObligationTriggerRequest{ ObligationValue: &common.IdFqnIdentifier{Id: s.obligationValue.GetId()}, @@ -306,30 +522,163 @@ func (s *ObligationTriggersSuite) createGenericTrigger() *policy.ObligationTrigg }, }) s.Require().NoError(err) - s.validateTrigger(trigger, true) + s.validateTriggerWithDefaults(trigger, true) return trigger } -func (s *ObligationTriggersSuite) validateTrigger(trigger *policy.ObligationTrigger, shouldHaveCtx bool) { - s.Require().NotNil(trigger) - s.Require().NotEmpty(trigger.GetId()) - s.Require().Equal(s.attributeValue.GetId(), trigger.GetAttributeValue().GetId()) - s.Require().Equal(s.attributeValue.GetFqn(), trigger.GetAttributeValue().GetFqn()) - s.Require().Equal(s.attributeValue.GetValue(), trigger.GetAttributeValue().GetValue()) - s.Require().Equal(s.obligationValue.GetId(), trigger.GetObligationValue().GetId()) - s.Require().Equal(s.obligationValue.GetValue(), trigger.GetObligationValue().GetValue()) - s.Require().Equal(s.obligationValue.GetObligation().GetId(), trigger.GetObligationValue().GetObligation().GetId()) - s.Require().Equal(s.obligationValue.GetObligation().GetName(), trigger.GetObligationValue().GetObligation().GetName()) - s.Require().Equal(s.obligationValue.GetObligation().GetNamespace().GetFqn(), trigger.GetObligationValue().GetObligation().GetNamespace().GetFqn()) - s.Require().Empty(trigger.GetObligationValue().GetTriggers()) - s.Require().Equal(s.action.GetId(), trigger.GetAction().GetId()) - s.Require().Equal(s.action.GetName(), trigger.GetAction().GetName()) +func (s *ObligationTriggersSuite) createUniqueTrigger(uniqueSuffix string) *policy.ObligationTrigger { + // Create a unique obligation value for this trigger + uniqueObligationValue, err := s.db.PolicyClient.CreateObligationValue(s.ctx, &obligations.CreateObligationValueRequest{ + ObligationId: s.obligation.GetId(), + Value: obligationValue + "-" + uniqueSuffix, + }) + s.Require().NoError(err) + + trigger, err := s.db.PolicyClient.CreateObligationTrigger(s.ctx, &obligations.AddObligationTriggerRequest{ + ObligationValue: &common.IdFqnIdentifier{Id: uniqueObligationValue.GetId()}, + AttributeValue: &common.IdFqnIdentifier{Id: s.attributeValue.GetId()}, + Action: &common.IdNameIdentifier{Id: s.action.GetId()}, + Metadata: &common.MetadataMutable{}, + Context: &policy.RequestContext{ + Pep: &policy.PolicyEnforcementPoint{ + ClientId: clientID, + }, + }, + }) + s.Require().NoError(err) + + // Validate the trigger with the unique obligation value using the refactored method + s.validateTrigger(trigger, uniqueObligationValue, s.attributeValue, s.action, true) + + return trigger +} + +func (s *ObligationTriggersSuite) createMultipleUniqueTriggers(count int) map[string]*policy.ObligationTrigger { + triggersMap := make(map[string]*policy.ObligationTrigger) + for i := range count { + trigger := s.createUniqueTrigger(fmt.Sprintf("trigger-%d", i)) + triggersMap[trigger.GetId()] = trigger + s.triggerIDsToClean = append(s.triggerIDsToClean, trigger.GetId()) + } + return triggersMap +} + +// validateTrigger validates that the actual trigger matches the expected values +func (s *ObligationTriggersSuite) validateTrigger(actual *policy.ObligationTrigger, expectedObligationValue *policy.ObligationValue, expectedAttributeValue *policy.Value, expectedAction *policy.Action, shouldHaveCtx bool) { + s.Require().NotNil(actual) + s.Require().NotEmpty(actual.GetId()) + + // Validate attribute value + s.Require().Equal(expectedAttributeValue.GetId(), actual.GetAttributeValue().GetId()) + s.Require().Equal(expectedAttributeValue.GetFqn(), actual.GetAttributeValue().GetFqn()) + s.Require().Equal(expectedAttributeValue.GetValue(), actual.GetAttributeValue().GetValue()) + + // Validate obligation value + s.Require().Equal(expectedObligationValue.GetId(), actual.GetObligationValue().GetId()) + s.Require().Equal(expectedObligationValue.GetValue(), actual.GetObligationValue().GetValue()) + s.Require().Equal(expectedObligationValue.GetObligation().GetId(), actual.GetObligationValue().GetObligation().GetId()) + s.Require().Equal(expectedObligationValue.GetObligation().GetName(), actual.GetObligationValue().GetObligation().GetName()) + s.Require().Equal(expectedObligationValue.GetObligation().GetNamespace().GetFqn(), actual.GetObligationValue().GetObligation().GetNamespace().GetFqn()) + s.Require().Empty(actual.GetObligationValue().GetTriggers()) + + // Validate action + s.Require().Equal(expectedAction.GetId(), actual.GetAction().GetId()) + s.Require().Equal(expectedAction.GetName(), actual.GetAction().GetName()) + + // Validate context if shouldHaveCtx { - s.Require().NotNil(trigger.GetContext()) - s.Require().Len(trigger.GetContext(), 1) - s.Require().NotNil(trigger.GetContext()[0].GetPep()) - s.Require().Equal(clientID, trigger.GetContext()[0].GetPep().GetClientId()) + s.Require().NotNil(actual.GetContext()) + s.Require().Len(actual.GetContext(), 1) + s.Require().NotNil(actual.GetContext()[0].GetPep()) + s.Require().Equal(clientID, actual.GetContext()[0].GetPep().GetClientId()) } else { - s.Require().Empty(trigger.GetContext()) + s.Require().Empty(actual.GetContext()) + } +} + +func (s *ObligationTriggersSuite) validatePageResponses(pageResult *policy.PageResponse, expectedTotal, expectedCurrentOffset, expectedNextOffset int32) { + s.Require().NotNil(pageResult) + s.Require().Equal(expectedTotal, pageResult.GetTotal()) + s.Require().Equal(expectedCurrentOffset, pageResult.GetCurrentOffset()) + s.Require().Equal(expectedNextOffset, pageResult.GetNextOffset()) +} + +// validateTriggerWithDefaults validates a trigger against the suite's default values for backward compatibility +func (s *ObligationTriggersSuite) validateTriggerWithDefaults(trigger *policy.ObligationTrigger, shouldHaveCtx bool) { + s.validateTrigger(trigger, s.obligationValue, s.attributeValue, s.action, shouldHaveCtx) +} + +func (s *ObligationTriggersSuite) appendObligationValuesToClean(createdTriggers map[string]*policy.ObligationTrigger) { + for _, trigger := range createdTriggers { + s.obligationValueIDsToClean = append(s.obligationValueIDsToClean, trigger.GetObligationValue().GetId()) + } +} + +func (s *ObligationTriggersSuite) createDifferentNamespaceWithTrigger(namespaceName string) *DifferentNamespaceEntities { + // Create a different namespace + differentNamespace, err := s.db.PolicyClient.CreateNamespace(s.ctx, &namespaces.CreateNamespaceRequest{ + Name: namespaceName, + }) + s.Require().NoError(err) + + // Create obligation in different namespace + differentObligation, err := s.db.PolicyClient.CreateObligation(s.ctx, &obligations.CreateObligationRequest{ + Name: "different-obligation-" + namespaceName, + NamespaceId: differentNamespace.GetId(), + }) + s.Require().NoError(err) + + // Create obligation value in different namespace + differentObligationValue, err := s.db.PolicyClient.CreateObligationValue(s.ctx, &obligations.CreateObligationValueRequest{ + ObligationId: differentObligation.GetId(), + Value: "different-obligation-value-" + namespaceName, + }) + s.Require().NoError(err) + + // Create attribute in different namespace + differentAttribute, err := s.db.PolicyClient.CreateAttribute(s.ctx, &attributes.CreateAttributeRequest{ + Name: "different-attribute-" + namespaceName, + NamespaceId: differentNamespace.GetId(), + Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF, + }) + s.Require().NoError(err) + + // Create attribute value in different namespace + differentAttributeValue, err := s.db.PolicyClient.CreateAttributeValue(s.ctx, differentAttribute.GetId(), &attributes.CreateAttributeValueRequest{ + Value: "different-value-" + namespaceName, + AttributeId: differentAttribute.GetId(), + }) + s.Require().NoError(err) + + // Create trigger in different namespace + differentTrigger, err := s.db.PolicyClient.CreateObligationTrigger(s.ctx, &obligations.AddObligationTriggerRequest{ + ObligationValue: &common.IdFqnIdentifier{Id: differentObligationValue.GetId()}, + AttributeValue: &common.IdFqnIdentifier{Id: differentAttributeValue.GetId()}, + Action: &common.IdNameIdentifier{Id: s.action.GetId()}, + Context: &policy.RequestContext{ + Pep: &policy.PolicyEnforcementPoint{ + ClientId: clientID, + }, + }, + }) + s.Require().NoError(err) + + return &DifferentNamespaceEntities{ + Namespace: differentNamespace, + Obligation: differentObligation, + ObligationValue: differentObligationValue, + Attribute: differentAttribute, + AttributeValue: differentAttributeValue, + Trigger: differentTrigger, + CleanupNamespace: func() { + _, err := s.db.PolicyClient.UnsafeDeleteNamespace(s.ctx, differentNamespace, differentNamespace.GetFqn()) + s.Require().NoError(err) + }, + CleanupTrigger: func() { + _, err := s.db.PolicyClient.DeleteObligationTrigger(s.ctx, &obligations.RemoveObligationTriggerRequest{ + Id: differentTrigger.GetId(), + }) + s.Require().NoError(err) + }, } } diff --git a/service/policy/db/obligations.go b/service/policy/db/obligations.go index 31174d90b..339d43df0 100644 --- a/service/policy/db/obligations.go +++ b/service/policy/db/obligations.go @@ -663,3 +663,53 @@ func (c PolicyDBClient) DeleteObligationTrigger(ctx context.Context, r *obligati Id: id, }, nil } + +func (c PolicyDBClient) ListObligationTriggers(ctx context.Context, r *obligations.ListObligationTriggersRequest) ([]*policy.ObligationTrigger, *policy.PageResponse, error) { + limit, offset := c.getRequestedLimitOffset(r.GetPagination()) + + maxLimit := c.listCfg.limitMax + if maxLimit > 0 && limit > maxLimit { + return nil, nil, db.ErrListLimitTooLarge + } + + rows, err := c.queries.listObligationTriggers(ctx, listObligationTriggersParams{ + NamespaceID: r.GetNamespaceId(), + NamespaceFqn: r.GetNamespaceFqn(), + Offset: offset, + Limit: limit, + }) + if err != nil { + return nil, nil, db.WrapIfKnownInvalidQueryErr(err) + } + + var result []*policy.ObligationTrigger + for _, row := range rows { + metadata := &common.Metadata{} + if err := unmarshalMetadata(row.Metadata, metadata); err != nil { + return nil, nil, err + } + + obligationTrigger, err := unmarshalObligationTrigger(row.Trigger) + if err != nil { + return nil, nil, err + } + + obligationTrigger.Metadata = metadata + result = append(result, obligationTrigger) + } + + var total int32 + var nextOffset int32 + if len(rows) > 0 { + total = int32(rows[0].Total) + nextOffset = getNextOffset(offset, limit, total) + } + + pageResult := &policy.PageResponse{ + CurrentOffset: offset, + Total: total, + NextOffset: nextOffset, + } + + return result, pageResult, nil +} diff --git a/service/policy/db/obligations.sql.go b/service/policy/db/obligations.sql.go index bd97993dc..6d5b4deff 100644 --- a/service/policy/db/obligations.sql.go +++ b/service/policy/db/obligations.sql.go @@ -368,7 +368,7 @@ WITH obligation_lookup AS ( -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) + (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -427,7 +427,7 @@ type createObligationValueRow struct { // -- lookup by obligation id OR by namespace fqn + obligation name // ( // -- lookup by obligation id -// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) +// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) // OR // -- lookup by namespace fqn + obligation name // (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -502,7 +502,7 @@ WHERE id IN ( -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) + (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -530,7 +530,7 @@ type deleteObligationParams struct { // -- lookup by obligation id OR by namespace fqn + obligation name // ( // -- lookup by obligation id -// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) +// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) // OR // -- lookup by namespace fqn + obligation name // (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -574,7 +574,7 @@ WHERE id IN ( -- lookup by value id OR by namespace fqn + obligation name + value name ( -- lookup by value id - (NULLIF($1::TEXT, '') IS NOT NULL AND ov.id = $1::UUID) + (NULLIF($1::TEXT, '') IS NOT NULL AND ov.id = NULLIF($1::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name + value name (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL AND NULLIF($4::TEXT, '') IS NOT NULL @@ -604,7 +604,7 @@ type deleteObligationValueParams struct { // -- lookup by value id OR by namespace fqn + obligation name + value name // ( // -- lookup by value id -// (NULLIF($1::TEXT, '') IS NOT NULL AND ov.id = $1::UUID) +// (NULLIF($1::TEXT, '') IS NOT NULL AND ov.id = NULLIF($1::TEXT, '')::UUID) // OR // -- lookup by namespace fqn + obligation name + value name // (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL AND NULLIF($4::TEXT, '') IS NOT NULL @@ -683,7 +683,7 @@ WHERE -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) + (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -766,7 +766,7 @@ type getObligationRow struct { // -- lookup by obligation id OR by namespace fqn + obligation name // ( // -- lookup by obligation id -// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = $1::UUID) +// (NULLIF($1::TEXT, '') IS NOT NULL AND od.id = NULLIF($1::TEXT, '')::UUID) // OR // -- lookup by namespace fqn + obligation name // (NULLIF($2::TEXT, '') IS NOT NULL AND NULLIF($3::TEXT, '') IS NOT NULL @@ -1295,6 +1295,169 @@ func (q *Queries) getObligationsByFQNs(ctx context.Context, arg getObligationsBy return items, nil } +const listObligationTriggers = `-- name: listObligationTriggers :many +SELECT + JSON_STRIP_NULLS( + JSON_BUILD_OBJECT( + 'id', ot.id, + 'obligation_value', JSON_BUILD_OBJECT( + 'id', ov.id, + 'value', ov.value, + 'obligation', JSON_BUILD_OBJECT( + 'id', od.id, + 'name', od.name, + 'namespace', JSON_BUILD_OBJECT( + 'id', n.id, + 'name', n.name, + 'fqn', COALESCE(ns_fqns.fqn, '') + ) + ) + ), + 'action', JSON_BUILD_OBJECT( + 'id', a.id, + 'name', a.name + ), + 'attribute_value', JSON_BUILD_OBJECT( + 'id', av.id, + 'value', av.value, + 'fqn', COALESCE(av_fqns.fqn, '') + ), + 'context', CASE + WHEN ot.client_id IS NOT NULL THEN JSON_BUILD_ARRAY( + JSON_BUILD_OBJECT( + 'pep', JSON_BUILD_OBJECT( + 'client_id', ot.client_id + ) + ) + ) + ELSE '[]'::JSON + END + ) + ) as trigger, + JSON_STRIP_NULLS( + JSON_BUILD_OBJECT( + 'labels', ot.metadata -> 'labels', + 'created_at', ot.created_at, + 'updated_at', ot.updated_at + ) + ) as metadata, + COUNT(*) OVER() as total +FROM obligation_triggers ot +JOIN obligation_values_standard ov ON ot.obligation_value_id = ov.id +JOIN obligation_definitions od ON ov.obligation_definition_id = od.id +JOIN attribute_namespaces n ON od.namespace_id = n.id +LEFT JOIN attribute_fqns ns_fqns ON ns_fqns.namespace_id = n.id AND ns_fqns.attribute_id IS NULL AND ns_fqns.value_id IS NULL +JOIN actions a ON ot.action_id = a.id +JOIN attribute_values av ON ot.attribute_value_id = av.id +LEFT JOIN attribute_fqns av_fqns ON av_fqns.value_id = av.id +WHERE + (NULLIF($1::TEXT, '') IS NULL OR od.namespace_id = $1::UUID) AND + (NULLIF($2::TEXT, '') IS NULL OR ns_fqns.fqn = $2::VARCHAR) +ORDER BY ot.created_at DESC +LIMIT $4 +OFFSET $3 +` + +type listObligationTriggersParams struct { + NamespaceID string `json:"namespace_id"` + NamespaceFqn string `json:"namespace_fqn"` + Offset int32 `json:"offset_"` + Limit int32 `json:"limit_"` +} + +type listObligationTriggersRow struct { + Trigger []byte `json:"trigger"` + Metadata []byte `json:"metadata"` + Total int64 `json:"total"` +} + +// listObligationTriggers +// +// SELECT +// JSON_STRIP_NULLS( +// JSON_BUILD_OBJECT( +// 'id', ot.id, +// 'obligation_value', JSON_BUILD_OBJECT( +// 'id', ov.id, +// 'value', ov.value, +// 'obligation', JSON_BUILD_OBJECT( +// 'id', od.id, +// 'name', od.name, +// 'namespace', JSON_BUILD_OBJECT( +// 'id', n.id, +// 'name', n.name, +// 'fqn', COALESCE(ns_fqns.fqn, '') +// ) +// ) +// ), +// 'action', JSON_BUILD_OBJECT( +// 'id', a.id, +// 'name', a.name +// ), +// 'attribute_value', JSON_BUILD_OBJECT( +// 'id', av.id, +// 'value', av.value, +// 'fqn', COALESCE(av_fqns.fqn, '') +// ), +// 'context', CASE +// WHEN ot.client_id IS NOT NULL THEN JSON_BUILD_ARRAY( +// JSON_BUILD_OBJECT( +// 'pep', JSON_BUILD_OBJECT( +// 'client_id', ot.client_id +// ) +// ) +// ) +// ELSE '[]'::JSON +// END +// ) +// ) as trigger, +// JSON_STRIP_NULLS( +// JSON_BUILD_OBJECT( +// 'labels', ot.metadata -> 'labels', +// 'created_at', ot.created_at, +// 'updated_at', ot.updated_at +// ) +// ) as metadata, +// COUNT(*) OVER() as total +// FROM obligation_triggers ot +// JOIN obligation_values_standard ov ON ot.obligation_value_id = ov.id +// JOIN obligation_definitions od ON ov.obligation_definition_id = od.id +// JOIN attribute_namespaces n ON od.namespace_id = n.id +// LEFT JOIN attribute_fqns ns_fqns ON ns_fqns.namespace_id = n.id AND ns_fqns.attribute_id IS NULL AND ns_fqns.value_id IS NULL +// JOIN actions a ON ot.action_id = a.id +// JOIN attribute_values av ON ot.attribute_value_id = av.id +// LEFT JOIN attribute_fqns av_fqns ON av_fqns.value_id = av.id +// WHERE +// (NULLIF($1::TEXT, '') IS NULL OR od.namespace_id = $1::UUID) AND +// (NULLIF($2::TEXT, '') IS NULL OR ns_fqns.fqn = $2::VARCHAR) +// ORDER BY ot.created_at DESC +// LIMIT $4 +// OFFSET $3 +func (q *Queries) listObligationTriggers(ctx context.Context, arg listObligationTriggersParams) ([]listObligationTriggersRow, error) { + rows, err := q.db.Query(ctx, listObligationTriggers, + arg.NamespaceID, + arg.NamespaceFqn, + arg.Offset, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []listObligationTriggersRow + for rows.Next() { + var i listObligationTriggersRow + if err := rows.Scan(&i.Trigger, &i.Metadata, &i.Total); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listObligations = `-- name: listObligations :many WITH counted AS ( SELECT COUNT(od.id) AS total diff --git a/service/policy/db/queries/obligations.sql b/service/policy/db/queries/obligations.sql index 61f684681..ac5933d82 100644 --- a/service/policy/db/queries/obligations.sql +++ b/service/policy/db/queries/obligations.sql @@ -109,7 +109,7 @@ WHERE -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = @id::UUID) + (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = NULLIF(@id::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF(@namespace_fqn::TEXT, '') IS NOT NULL AND NULLIF(@name::TEXT, '') IS NOT NULL @@ -208,7 +208,7 @@ WHERE id IN ( -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = @id::UUID) + (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = NULLIF(@id::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF(@namespace_fqn::TEXT, '') IS NOT NULL AND NULLIF(@name::TEXT, '') IS NOT NULL @@ -301,7 +301,7 @@ WITH obligation_lookup AS ( -- lookup by obligation id OR by namespace fqn + obligation name ( -- lookup by obligation id - (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = @id::UUID) + (NULLIF(@id::TEXT, '') IS NOT NULL AND od.id = NULLIF(@id::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name (NULLIF(@namespace_fqn::TEXT, '') IS NOT NULL AND NULLIF(@name::TEXT, '') IS NOT NULL @@ -472,7 +472,7 @@ WHERE id IN ( -- lookup by value id OR by namespace fqn + obligation name + value name ( -- lookup by value id - (NULLIF(@id::TEXT, '') IS NOT NULL AND ov.id = @id::UUID) + (NULLIF(@id::TEXT, '') IS NOT NULL AND ov.id = NULLIF(@id::TEXT, '')::UUID) OR -- lookup by namespace fqn + obligation name + value name (NULLIF(@namespace_fqn::TEXT, '') IS NOT NULL AND NULLIF(@name::TEXT, '') IS NOT NULL AND NULLIF(@value::TEXT, '') IS NOT NULL @@ -584,3 +584,66 @@ WHERE obligation_value_id = $1; DELETE FROM obligation_triggers WHERE id = $1 RETURNING id; + +-- name: listObligationTriggers :many +SELECT + JSON_STRIP_NULLS( + JSON_BUILD_OBJECT( + 'id', ot.id, + 'obligation_value', JSON_BUILD_OBJECT( + 'id', ov.id, + 'value', ov.value, + 'obligation', JSON_BUILD_OBJECT( + 'id', od.id, + 'name', od.name, + 'namespace', JSON_BUILD_OBJECT( + 'id', n.id, + 'name', n.name, + 'fqn', COALESCE(ns_fqns.fqn, '') + ) + ) + ), + 'action', JSON_BUILD_OBJECT( + 'id', a.id, + 'name', a.name + ), + 'attribute_value', JSON_BUILD_OBJECT( + 'id', av.id, + 'value', av.value, + 'fqn', COALESCE(av_fqns.fqn, '') + ), + 'context', CASE + WHEN ot.client_id IS NOT NULL THEN JSON_BUILD_ARRAY( + JSON_BUILD_OBJECT( + 'pep', JSON_BUILD_OBJECT( + 'client_id', ot.client_id + ) + ) + ) + ELSE '[]'::JSON + END + ) + ) as trigger, + JSON_STRIP_NULLS( + JSON_BUILD_OBJECT( + 'labels', ot.metadata -> 'labels', + 'created_at', ot.created_at, + 'updated_at', ot.updated_at + ) + ) as metadata, + COUNT(*) OVER() as total +FROM obligation_triggers ot +JOIN obligation_values_standard ov ON ot.obligation_value_id = ov.id +JOIN obligation_definitions od ON ov.obligation_definition_id = od.id +JOIN attribute_namespaces n ON od.namespace_id = n.id +LEFT JOIN attribute_fqns ns_fqns ON ns_fqns.namespace_id = n.id AND ns_fqns.attribute_id IS NULL AND ns_fqns.value_id IS NULL +JOIN actions a ON ot.action_id = a.id +JOIN attribute_values av ON ot.attribute_value_id = av.id +LEFT JOIN attribute_fqns av_fqns ON av_fqns.value_id = av.id +WHERE + (NULLIF(@namespace_id::TEXT, '') IS NULL OR od.namespace_id = @namespace_id::UUID) AND + (NULLIF(@namespace_fqn::TEXT, '') IS NULL OR ns_fqns.fqn = @namespace_fqn::VARCHAR) +ORDER BY ot.created_at DESC +LIMIT @limit_ +OFFSET @offset_; + diff --git a/service/policy/obligations/obligations.go b/service/policy/obligations/obligations.go index 08883ec3c..efae842d2 100644 --- a/service/policy/obligations/obligations.go +++ b/service/policy/obligations/obligations.go @@ -2,7 +2,6 @@ package obligations import ( "context" - "errors" "fmt" "log/slog" @@ -416,8 +415,18 @@ func (s *Service) RemoveObligationTrigger(ctx context.Context, req *connect.Requ return connect.NewResponse(rsp), nil } -func (s *Service) ListObligationTriggers(_ context.Context, _ *connect.Request[obligations.ListObligationTriggersRequest]) (*connect.Response[obligations.ListObligationTriggersResponse], error) { - return nil, connect.NewError(connect.CodeUnimplemented, errors.New("listObligationTriggers is not yet implemented")) +func (s *Service) ListObligationTriggers(ctx context.Context, req *connect.Request[obligations.ListObligationTriggersRequest]) (*connect.Response[obligations.ListObligationTriggersResponse], error) { + s.logger.DebugContext(ctx, "listing obligation triggers") + + triggers, pr, err := s.dbClient.ListObligationTriggers(ctx, req.Msg) + if err != nil { + return nil, db.StatusifyError(ctx, s.logger, err, db.ErrTextListRetrievalFailed) + } + rsp := &obligations.ListObligationTriggersResponse{ + Triggers: triggers, + Pagination: pr, + } + return connect.NewResponse(rsp), nil } // func (s *Service) AddObligationFulfiller(_ context.Context, _ *connect.Request[obligations.AddObligationFulfillerRequest]) (*connect.Response[obligations.AddObligationFulfillerResponse], error) { diff --git a/service/policy/obligations/obligations_test.go b/service/policy/obligations/obligations_test.go index a6b8db2c1..d37d4f3d3 100644 --- a/service/policy/obligations/obligations_test.go +++ b/service/policy/obligations/obligations_test.go @@ -1144,3 +1144,105 @@ func Test_UpdateObligationValue_Request(t *testing.T) { }) } } + +func Test_ListObligationTriggers_Request(t *testing.T) { + validUUID := uuid.NewString() + + testCases := []struct { + name string + req *obligations.ListObligationTriggersRequest + expectError bool + errorMessage string + }{ + { + name: "valid - no filters", + req: &obligations.ListObligationTriggersRequest{}, + expectError: false, + }, + { + name: "valid - with namespace_id", + req: &obligations.ListObligationTriggersRequest{ + NamespaceId: validUUID, + }, + expectError: false, + }, + { + name: "valid - with namespace_fqn", + req: &obligations.ListObligationTriggersRequest{ + NamespaceFqn: validFQN1, + }, + expectError: false, + }, + { + name: "valid - with pagination only", + req: &obligations.ListObligationTriggersRequest{ + Pagination: &policy.PageRequest{ + Limit: 10, + Offset: 5, + }, + }, + expectError: false, + }, + { + name: "valid - namespace_id with pagination", + req: &obligations.ListObligationTriggersRequest{ + NamespaceId: validUUID, + Pagination: &policy.PageRequest{ + Limit: 20, + Offset: 0, + }, + }, + expectError: false, + }, + { + name: "valid - namespace_fqn with pagination", + req: &obligations.ListObligationTriggersRequest{ + NamespaceFqn: validFQN1, + Pagination: &policy.PageRequest{ + Limit: 15, + Offset: 10, + }, + }, + expectError: false, + }, + { + name: "invalid namespace_id", + req: &obligations.ListObligationTriggersRequest{ + NamespaceId: invalidUUID, + }, + expectError: true, + errorMessage: errMessageUUID, + }, + { + name: "invalid namespace_fqn", + req: &obligations.ListObligationTriggersRequest{ + NamespaceFqn: invalidFQN, + }, + expectError: true, + errorMessage: errMessageURI, + }, + { + name: "both namespace_id and namespace_fqn", + req: &obligations.ListObligationTriggersRequest{ + NamespaceId: validUUID, + NamespaceFqn: validFQN1, + }, + expectError: true, + errorMessage: errMessageOneOf, + }, + } + + v := getValidator() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := v.Validate(tc.req) + if tc.expectError { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errorMessage) + } else { + require.NoError(t, err) + } + }) + } +}