Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integration/attribute_values_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,8 @@ func (s *AttributeValuesSuite) Test_RemoveKeyAccessServerFromValue_Returns_Error

func (s *AttributeValuesSuite) Test_RemoveKeyAccessServerFromValue_Returns_Success_When_Value_And_KeyAccessServer_Exist() {
v := &attributes.ValueKeyAccessServer{
ValueId: fixtures.GetAttributeValueKey("example.net/attr/attr1/value/value1").Id,
KeyAccessServerId: fixtureKeyAccessServerId,
ValueId: fixtures.GetAttributeValueKey("example.com/attr/attr1/value/value1").Id,
KeyAccessServerId: fixtures.GetKasRegistryKey("key_access_server_1").Id,
}

resp, err := s.db.Client.RemoveKeyAccessServerFromValue(s.ctx, v)
Expand Down
4 changes: 2 additions & 2 deletions integration/attributes_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,8 @@ func (s *AttributesSuite) Test_RemoveKeyAccessServerFromAttribute_Returns_Error_

func (s *AttributesSuite) Test_RemoveKeyAccessServerFromAttribute_Returns_Success_When_Attribute_And_KeyAccessServer_Exist() {
aKas := &attributes.AttributeKeyAccessServer{
AttributeId: fixtures.GetAttributeKey("example.com/attr/attr2").Id,
KeyAccessServerId: fixtureKeyAccessServerId,
AttributeId: fixtures.GetAttributeKey("example.com/attr/attr1").Id,
KeyAccessServerId: fixtures.GetKasRegistryKey("key_access_server_1").Id,
}
resp, err := s.db.Client.RemoveKeyAccessServerFromAttribute(s.ctx, aKas)

Expand Down
4 changes: 2 additions & 2 deletions integration/fixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (f *Fixtures) Provision() {
slog.Info("📦 provisioning attribute key access server data")
akas := f.provisionAttributeKeyAccessServer()
slog.Info("📦 provisioning attribute value key access server data")
akas = f.provisionAttributeValueKeyAccessServer()
avkas := f.provisionAttributeValueKeyAccessServer()

slog.Info("📦 provisioned fixtures data",
slog.Int64("namespaces", n),
Expand All @@ -201,7 +201,7 @@ func (f *Fixtures) Provision() {
slog.Int64("resource_mappings", rM),
slog.Int64("kas_registry", kas),
slog.Int64("attribute_key_access_server", akas),
slog.Int64("attribute_value_key_access_server", akas),
slog.Int64("attribute_value_key_access_server", avkas),
)
}

Expand Down
4 changes: 4 additions & 0 deletions integration/fixtures.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ attributes:
attribute_key_access_servers:
- attribute_id: 00000000-0000-0000-0000-000000000000
key_access_server_id: 00000000-0000-0000-0000-000000000000
- attribute_id: 00000000-0000-0000-0000-000000000000
key_access_server_id: 00000000-0000-0000-0000-000000000001

##
# Attribute Values
Expand Down Expand Up @@ -126,6 +128,8 @@ attribute_values:
attribute_value_key_access_servers:
- value_id: 00000000-0000-0000-0000-000000000000
key_access_server_id: 00000000-0000-0000-0000-000000000000
- value_id: 00000000-0000-0000-0000-000000000000
key_access_server_id: 00000000-0000-0000-0000-000000000001

##
# Subject Mappings
Expand Down
21 changes: 11 additions & 10 deletions internal/db/attribute_values.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ var AttributeValueTable = tableName(TableAttributeValues)

func attributeValueHydrateItem(row pgx.Row) (*attributes.Value, error) {
var (
id string
value string
members []string
metadataJson []byte
attributeId string
id string
value string
members []string
metadataJson []byte
attributeId string
)
if err := row.Scan(&id, &value, &members, &metadataJson, &attributeId); err != nil {
return nil, err
Expand All @@ -32,10 +32,10 @@ func attributeValueHydrateItem(row pgx.Row) (*attributes.Value, error) {
}

v := &attributes.Value{
Id: id,
Value: value,
Members: members,
Metadata: m,
Id: id,
Value: value,
Members: members,
Metadata: m,
AttributeId: attributeId,
}
return v, nil
Expand Down Expand Up @@ -256,6 +256,7 @@ func removeKeyAccessServerFromValueSql(valueID, keyAccessServerID string) (strin
return newStatementBuilder().
Delete(t.Name()).
Where(sq.Eq{"attribute_value_id": valueID, "key_access_server_id": keyAccessServerID}).
Suffix("IS TRUE RETURNING *").
ToSql()
}

Expand All @@ -265,7 +266,7 @@ func (c Client) RemoveKeyAccessServerFromValue(ctx context.Context, k *attribute
return nil, err
}

if err := c.exec(ctx, sql, args, err); err != nil {
if _, err := c.queryCount(ctx, sql, args); err != nil {
return nil, err
}

Expand Down
6 changes: 5 additions & 1 deletion internal/db/attributes.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,13 +417,17 @@ func removeKeyAccessServerFromAttributeSql(attributeID, keyAccessServerID string
return newStatementBuilder().
Delete(t.Name()).
Where(sq.Eq{"attribute_definition_id": attributeID, "key_access_server_id": keyAccessServerID}).
Suffix("IS TRUE RETURNING *").
ToSql()
}

func (c Client) RemoveKeyAccessServerFromAttribute(ctx context.Context, k *attributes.AttributeKeyAccessServer) (*attributes.AttributeKeyAccessServer, error) {
sql, args, err := removeKeyAccessServerFromAttributeSql(k.AttributeId, k.KeyAccessServerId)
if err != nil {
return nil, err
}

if err := c.exec(ctx, sql, args, err); err != nil {
if _, err := c.queryCount(ctx, sql, args); err != nil {
return nil, err
}

Expand Down
21 changes: 21 additions & 0 deletions internal/db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,27 @@ func (c Client) query(ctx context.Context, sql string, args []interface{}, err e
return r, WrapIfKnownInvalidQueryErr(e)
}

func (c Client) queryCount(ctx context.Context, sql string, args []interface{}) (int, error) {
rows, err := c.query(ctx, sql, args, nil)
if err != nil {
return 0, err
}
defer rows.Close()

count := 0
for rows.Next() {
if _, err := rows.Values(); err != nil {
return 0, err
}
count++
}
if count == 0 {
return 0, pgx.ErrNoRows
}

return count, nil
}

// Common function for all exec calls
func (c Client) exec(ctx context.Context, sql string, args []interface{}, err error) error {
slog.Debug("sql", slog.String("sql", sql), slog.Any("args", args))
Expand Down