Skip to content
169 changes: 143 additions & 26 deletions service/integration/kas_registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/opentdf/platform/protocol/go/policy"
"github.com/opentdf/platform/protocol/go/policy/attributes"
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
"github.com/opentdf/platform/protocol/go/policy/namespaces"
"github.com/opentdf/platform/service/internal/fixtures"
"github.com/opentdf/platform/service/pkg/db"

Expand Down Expand Up @@ -432,33 +433,82 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrantsByKasId() {
s.Require().NoError(err)
s.NotNil(createdAttr)

fixtureKAS := s.f.GetKasRegistryKey("key_access_server_2")
// create a value
val := &attributes.CreateAttributeValueRequest{
AttributeId: createdAttr.GetId(),
Value: "value2",
}
createdVal, err := s.db.PolicyClient.CreateAttributeValue(s.ctx, createdAttr.GetId(), val)
s.Require().NoError(err)
s.NotNil(createdVal)

// add a KAS to the attribute
firstKAS, err := s.db.PolicyClient.CreateKeyAccessServer(s.ctx, &kasregistry.CreateKeyAccessServerRequest{
Uri: "https://firstkas.com/kas/uri",
PublicKey: &policy.PublicKey{
PublicKey: &policy.PublicKey_Local{Local: "public"},
},
})
s.Require().NoError(err)
s.NotNil(firstKAS)
firstKAS, _ = s.db.PolicyClient.GetKeyAccessServer(s.ctx, firstKAS.GetId())

otherKAS, err := s.db.PolicyClient.CreateKeyAccessServer(s.ctx, &kasregistry.CreateKeyAccessServerRequest{
Uri: "https://otherkas.com/kas/uri",
PublicKey: &policy.PublicKey{
PublicKey: &policy.PublicKey_Local{Local: "public"},
},
})
s.Require().NoError(err)
otherKAS, _ = s.db.PolicyClient.GetKeyAccessServer(s.ctx, otherKAS.GetId())

// assign a KAS to the attribute
aKas := &attributes.AttributeKeyAccessServer{
AttributeId: createdAttr.GetId(),
KeyAccessServerId: fixtureKAS.ID,
KeyAccessServerId: firstKAS.GetId(),
}
createdGrant, err := s.db.PolicyClient.AssignKeyAccessServerToAttribute(s.ctx, aKas)
s.Require().NoError(err)
s.NotNil(createdGrant)

// assign a KAS to the value
bKas := &attributes.ValueKeyAccessServer{
ValueId: createdVal.GetId(),
KeyAccessServerId: otherKAS.GetId(),
}
valGrant, err := s.db.PolicyClient.AssignKeyAccessServerToValue(s.ctx, bKas)
s.Require().NoError(err)
s.NotNil(valGrant)

// list grants by KAS ID
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrantsByKasId(s.ctx, fixtureKAS.ID)
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, firstKAS.GetId(), "")
s.Require().NoError(err)
s.NotNil(listedGrants)
s.GreaterOrEqual(len(listedGrants), 1)
for _, g := range listedGrants {
s.Equal(fixtureKAS.ID, g.KasID)
s.Equal(fixtureKAS.URI, g.KasUri)
}
s.Len(listedGrants, 1)
g := listedGrants[0]
s.Equal(firstKAS.GetId(), g.GetKeyAccessServer().GetId())
s.Equal(firstKAS.GetUri(), g.GetKeyAccessServer().GetUri())
s.Len(g.GetAttributeGrants(), 1)
s.Empty(g.GetValueGrants())
s.Empty(g.GetNamespaceGrants())

// list grants by the other KAS ID
listedGrants, err = s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, otherKAS.GetId(), "")
s.Require().NoError(err)
s.NotNil(listedGrants)
s.Len(listedGrants, 1)
g = listedGrants[0]
s.Equal(otherKAS.GetId(), g.GetKeyAccessServer().GetId())
s.Equal(otherKAS.GetUri(), g.GetKeyAccessServer().GetUri())
s.Empty(g.GetAttributeGrants())
s.Len(g.GetValueGrants(), 1)
s.Empty(g.GetNamespaceGrants())
}

func (s *KasRegistrySuite) Test_ListKeyAccessServerGrantsByKasId_NoResultsIfNotFound() {
// list grants by KAS ID
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrantsByKasId(s.ctx, nonExistentKasRegistryID)
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, nonExistentKasRegistryID, "")
s.Require().NoError(err)
s.Nil(listedGrants)
s.Empty(listedGrants)
}

func (s *KasRegistrySuite) Test_ListKeyAccessServerGrantsByKasUri() {
Expand All @@ -484,21 +534,22 @@ func (s *KasRegistrySuite) Test_ListKeyAccessServerGrantsByKasUri() {
s.NotNil(createdGrant)

// list grants by KAS URI
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrantsByKasUri(s.ctx, fixtureKAS.URI)
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", fixtureKAS.URI)

s.Require().NoError(err)
s.NotNil(listedGrants)
s.GreaterOrEqual(len(listedGrants), 1)
for _, g := range listedGrants {
s.Equal(fixtureKAS.ID, g.KasID)
s.Equal(fixtureKAS.URI, g.KasUri)
s.Equal(fixtureKAS.ID, g.GetKeyAccessServer().GetId())
s.Equal(fixtureKAS.URI, g.GetKeyAccessServer().GetUri())
}
}

func (s *KasRegistrySuite) Test_ListKeyAccessServerGrantsByKasUri_NoResultsIfNotFound() {
// list grants by KAS ID
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrantsByKasUri(s.ctx, "https://notfound.com/kas/uri")
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "https://notfound.com/kas/uri")
s.Require().NoError(err)
s.Nil(listedGrants)
s.Empty(listedGrants)
}

func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() {
Expand All @@ -509,42 +560,108 @@ func (s *KasRegistrySuite) Test_ListAllKeyAccessServerGrants() {
PublicKey: &policy.PublicKey_Local{Local: "public"},
},
}
createdKAS, err := s.db.PolicyClient.CreateKeyAccessServer(s.ctx, kas)
firstKAS, err := s.db.PolicyClient.CreateKeyAccessServer(s.ctx, kas)
s.Require().NoError(err)
s.NotNil(createdKAS)
s.NotNil(firstKAS)

// create an attribute
attr := &attributes.CreateAttributeRequest{
Name: "test__list_all_key_access_server_grants",
NamespaceId: fixtureNamespaceID,
Rule: policy.AttributeRuleTypeEnum_ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF,
Values: []string{"value1"},
}
createdAttr, err := s.db.PolicyClient.CreateAttribute(s.ctx, attr)
s.Require().NoError(err)
s.NotNil(createdAttr)

// add a KAS to the attribute
got, err := s.db.PolicyClient.GetAttribute(s.ctx, createdAttr.GetId())
s.Require().NoError(err)
s.NotNil(got)
value := got.GetValues()[0]

// add first KAS to the attribute
aKas := &attributes.AttributeKeyAccessServer{
AttributeId: createdAttr.GetId(),
KeyAccessServerId: createdKAS.GetId(),
KeyAccessServerId: firstKAS.GetId(),
}
createdGrant, err := s.db.PolicyClient.AssignKeyAccessServerToAttribute(s.ctx, aKas)
s.Require().NoError(err)
s.NotNil(createdGrant)

// add another KAS and grant it to the value
second := &kasregistry.CreateKeyAccessServerRequest{
Uri: "https://listingkasgrants.com/another/kas/uri",
PublicKey: &policy.PublicKey{
PublicKey: &policy.PublicKey_Local{Local: "public"},
},
}
secondKAS, err := s.db.PolicyClient.CreateKeyAccessServer(s.ctx, second)
s.Require().NoError(err)
s.NotNil(secondKAS)

// assign a grant of the second KAS to the value
bKas := &attributes.ValueKeyAccessServer{
ValueId: value.GetId(),
KeyAccessServerId: secondKAS.GetId(),
}
valGrant, err := s.db.PolicyClient.AssignKeyAccessServerToValue(s.ctx, bKas)
s.Require().NoError(err)
s.NotNil(valGrant)

// grant each KAS to the namespace
nsKas := &namespaces.NamespaceKeyAccessServer{
NamespaceId: fixtureNamespaceID,
KeyAccessServerId: firstKAS.GetId(),
}
nsGrant, err := s.db.PolicyClient.AssignKeyAccessServerToNamespace(s.ctx, nsKas)
s.Require().NoError(err)
s.NotNil(nsGrant)

nsAnotherKas := &namespaces.NamespaceKeyAccessServer{
NamespaceId: fixtureNamespaceID,
KeyAccessServerId: secondKAS.GetId(),
}
nsAnotherGrant, err := s.db.PolicyClient.AssignKeyAccessServerToNamespace(s.ctx, nsAnotherKas)
s.Require().NoError(err)
s.NotNil(nsAnotherGrant)

// list all grants
listedGrants, err := s.db.PolicyClient.ListAllKeyAccessServerGrants(s.ctx)
listedGrants, err := s.db.PolicyClient.ListKeyAccessServerGrants(s.ctx, "", "")
s.Require().NoError(err)
s.NotNil(listedGrants)
s.GreaterOrEqual(len(listedGrants), 1)
found := false

for _, g := range listedGrants {
if g.KasID == createdKAS.GetId() {
found = true
break
if g.GetKeyAccessServer().GetId() == firstKAS.GetId() {
// should have expected attribute grant
grantedAttrIDs := make([]string, len(g.GetAttributeGrants()))
for i, a := range g.GetAttributeGrants() {
grantedAttrIDs[i] = a.GetId()
}
s.Contains(grantedAttrIDs, createdAttr.GetId())
// should have expected namespace grant
grantedNsIDs := make([]string, len(g.GetNamespaceGrants()))
for i, n := range g.GetNamespaceGrants() {
grantedNsIDs[i] = n.GetId()
}
s.Contains(grantedNsIDs, fixtureNamespaceID)
}
if g.GetKeyAccessServer().GetId() == secondKAS.GetId() {
// should have expected value grant
grantedValIDs := make([]string, len(g.GetValueGrants()))
for i, v := range g.GetValueGrants() {
grantedValIDs[i] = v.GetId()
}
s.Contains(grantedValIDs, value.GetId())
// should have expected namespace grant
grantedNsIDs := make([]string, len(g.GetNamespaceGrants()))
for i, n := range g.GetNamespaceGrants() {
grantedNsIDs[i] = n.GetId()
}
s.Contains(grantedNsIDs, fixtureNamespaceID)
}
}
s.True(found)
}

func TestKasRegistrySuite(t *testing.T) {
Expand Down
23 changes: 23 additions & 0 deletions service/pkg/db/marshalHelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/opentdf/platform/protocol/go/common"
"github.com/opentdf/platform/protocol/go/policy"
"github.com/opentdf/platform/protocol/go/policy/kasregistry"
"google.golang.org/protobuf/encoding/protojson"
)

Expand Down Expand Up @@ -80,3 +81,25 @@ func KeyAccessServerProtoJSON(keyAccessServerJSON []byte) ([]*policy.KeyAccessSe
}
return keyAccessServers, nil
}

func GrantedPolicyObjectProtoJSON(grantsJSON []byte) ([]*kasregistry.GrantedPolicyObject, error) {
var (
policyObjectGrants []*kasregistry.GrantedPolicyObject
raw []json.RawMessage
)
if grantsJSON == nil {
return nil, nil
}

if err := json.Unmarshal(grantsJSON, &raw); err != nil {
return nil, err
}
for _, r := range raw {
po := kasregistry.GrantedPolicyObject{}
if err := protojson.Unmarshal(r, &po); err != nil {
return nil, err
}
policyObjectGrants = append(policyObjectGrants, &po)
}
return policyObjectGrants, nil
}
45 changes: 45 additions & 0 deletions service/policy/db/key_access_server_registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package db

import (
"context"
"fmt"

"github.com/jackc/pgx/v5/pgtype"
"github.com/opentdf/platform/protocol/go/common"
Expand Down Expand Up @@ -157,3 +158,47 @@ func (c PolicyDBClient) DeleteKeyAccessServer(ctx context.Context, id string) (*
Id: id,
}, nil
}

func (c PolicyDBClient) ListKeyAccessServerGrants(ctx context.Context, kasID string, kasURI string) ([]*kasregistry.KeyAccessServerGrants, error) {
params := ListKeyAccessServerGrantsParams{
KasID: kasID,
KasUri: kasURI,
}
listRows, err := c.Queries.ListKeyAccessServerGrants(ctx, params)
if err != nil {
return nil, db.WrapIfKnownInvalidQueryErr(err)
}

grants := make([]*kasregistry.KeyAccessServerGrants, len(listRows))
for i, grant := range listRows {
pubKey := new(policy.PublicKey)
if err := protojson.Unmarshal(grant.KasPublicKey, pubKey); err != nil {
return nil, fmt.Errorf("failed to unmarshal KAS public key: %w", err)
}
kas := &policy.KeyAccessServer{
Id: grant.KasID,
Uri: grant.KasUri,
PublicKey: pubKey,
}
attrGrants, err := db.GrantedPolicyObjectProtoJSON(grant.AttributesGrants)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal attribute grants: %w", err)
}
valGrants, err := db.GrantedPolicyObjectProtoJSON(grant.ValuesGrants)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal value grants: %w", err)
}
namespaceGrants, err := db.GrantedPolicyObjectProtoJSON(grant.NamespaceGrants)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal namespace grants: %w", err)
}
grants[i] = &kasregistry.KeyAccessServerGrants{
KeyAccessServer: kas,
AttributeGrants: attrGrants,
ValueGrants: valGrants,
NamespaceGrants: namespaceGrants,
}
}

return grants, nil
}
Loading