diff --git a/integration/attribute_values_test.go b/integration/attribute_values_test.go index 22b079eb42..ea525f6d20 100644 --- a/integration/attribute_values_test.go +++ b/integration/attribute_values_test.go @@ -295,8 +295,8 @@ func (s *AttributeValuesSuite) Test_RemoveKeyAccessServerFromValue_Returns_Error func (s *AttributeValuesSuite) Test_RemoveKeyAccessServerFromValue_Returns_Success_When_Value_And_KeyAccessServer_Exist() { v := &attributes.ValueKeyAccessServer{ - ValueId: fixtures.GetAttributeValueKey("example.net/attr/attr1/value/value1").Id, - KeyAccessServerId: fixtureKeyAccessServerId, + ValueId: fixtures.GetAttributeValueKey("example.com/attr/attr1/value/value1").Id, + KeyAccessServerId: fixtures.GetKasRegistryKey("key_access_server_1").Id, } resp, err := s.db.Client.RemoveKeyAccessServerFromValue(s.ctx, v) diff --git a/integration/attributes_test.go b/integration/attributes_test.go index a83b518b0c..ba34a62bbe 100644 --- a/integration/attributes_test.go +++ b/integration/attributes_test.go @@ -363,8 +363,8 @@ func (s *AttributesSuite) Test_RemoveKeyAccessServerFromAttribute_Returns_Error_ func (s *AttributesSuite) Test_RemoveKeyAccessServerFromAttribute_Returns_Success_When_Attribute_And_KeyAccessServer_Exist() { aKas := &attributes.AttributeKeyAccessServer{ - AttributeId: fixtures.GetAttributeKey("example.com/attr/attr2").Id, - KeyAccessServerId: fixtureKeyAccessServerId, + AttributeId: fixtures.GetAttributeKey("example.com/attr/attr1").Id, + KeyAccessServerId: fixtures.GetKasRegistryKey("key_access_server_1").Id, } resp, err := s.db.Client.RemoveKeyAccessServerFromAttribute(s.ctx, aKas) diff --git a/integration/fixtures.go b/integration/fixtures.go index b590bb787d..85b75316a0 100644 --- a/integration/fixtures.go +++ b/integration/fixtures.go @@ -191,7 +191,7 @@ func (f *Fixtures) Provision() { slog.Info("📦 provisioning attribute key access server data") akas := f.provisionAttributeKeyAccessServer() slog.Info("📦 provisioning attribute value key access server data") - akas = f.provisionAttributeValueKeyAccessServer() + avkas := f.provisionAttributeValueKeyAccessServer() slog.Info("📦 provisioned fixtures data", slog.Int64("namespaces", n), @@ -201,7 +201,7 @@ func (f *Fixtures) Provision() { slog.Int64("resource_mappings", rM), slog.Int64("kas_registry", kas), slog.Int64("attribute_key_access_server", akas), - slog.Int64("attribute_value_key_access_server", akas), + slog.Int64("attribute_value_key_access_server", avkas), ) } diff --git a/integration/fixtures.yaml b/integration/fixtures.yaml index 0d803b44f4..7039f4bfaf 100644 --- a/integration/fixtures.yaml +++ b/integration/fixtures.yaml @@ -78,6 +78,8 @@ attributes: attribute_key_access_servers: - attribute_id: 00000000-0000-0000-0000-000000000000 key_access_server_id: 00000000-0000-0000-0000-000000000000 + - attribute_id: 00000000-0000-0000-0000-000000000000 + key_access_server_id: 00000000-0000-0000-0000-000000000001 ## # Attribute Values @@ -126,6 +128,8 @@ attribute_values: attribute_value_key_access_servers: - value_id: 00000000-0000-0000-0000-000000000000 key_access_server_id: 00000000-0000-0000-0000-000000000000 + - value_id: 00000000-0000-0000-0000-000000000000 + key_access_server_id: 00000000-0000-0000-0000-000000000001 ## # Subject Mappings diff --git a/internal/db/attribute_values.go b/internal/db/attribute_values.go index 9eb8ea2332..28bf28ac03 100644 --- a/internal/db/attribute_values.go +++ b/internal/db/attribute_values.go @@ -14,11 +14,11 @@ var AttributeValueTable = tableName(TableAttributeValues) func attributeValueHydrateItem(row pgx.Row) (*attributes.Value, error) { var ( - id string - value string - members []string - metadataJson []byte - attributeId string + id string + value string + members []string + metadataJson []byte + attributeId string ) if err := row.Scan(&id, &value, &members, &metadataJson, &attributeId); err != nil { return nil, err @@ -32,10 +32,10 @@ func attributeValueHydrateItem(row pgx.Row) (*attributes.Value, error) { } v := &attributes.Value{ - Id: id, - Value: value, - Members: members, - Metadata: m, + Id: id, + Value: value, + Members: members, + Metadata: m, AttributeId: attributeId, } return v, nil @@ -256,6 +256,7 @@ func removeKeyAccessServerFromValueSql(valueID, keyAccessServerID string) (strin return newStatementBuilder(). Delete(t.Name()). Where(sq.Eq{"attribute_value_id": valueID, "key_access_server_id": keyAccessServerID}). + Suffix("IS TRUE RETURNING *"). ToSql() } @@ -265,7 +266,7 @@ func (c Client) RemoveKeyAccessServerFromValue(ctx context.Context, k *attribute return nil, err } - if err := c.exec(ctx, sql, args, err); err != nil { + if _, err := c.queryCount(ctx, sql, args); err != nil { return nil, err } diff --git a/internal/db/attributes.go b/internal/db/attributes.go index febbabf1d4..c204bf416c 100644 --- a/internal/db/attributes.go +++ b/internal/db/attributes.go @@ -417,13 +417,17 @@ func removeKeyAccessServerFromAttributeSql(attributeID, keyAccessServerID string return newStatementBuilder(). Delete(t.Name()). Where(sq.Eq{"attribute_definition_id": attributeID, "key_access_server_id": keyAccessServerID}). + Suffix("IS TRUE RETURNING *"). ToSql() } func (c Client) RemoveKeyAccessServerFromAttribute(ctx context.Context, k *attributes.AttributeKeyAccessServer) (*attributes.AttributeKeyAccessServer, error) { sql, args, err := removeKeyAccessServerFromAttributeSql(k.AttributeId, k.KeyAccessServerId) + if err != nil { + return nil, err + } - if err := c.exec(ctx, sql, args, err); err != nil { + if _, err := c.queryCount(ctx, sql, args); err != nil { return nil, err } diff --git a/internal/db/db.go b/internal/db/db.go index de0134bd47..c09667c27b 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -152,6 +152,27 @@ func (c Client) query(ctx context.Context, sql string, args []interface{}, err e return r, WrapIfKnownInvalidQueryErr(e) } +func (c Client) queryCount(ctx context.Context, sql string, args []interface{}) (int, error) { + rows, err := c.query(ctx, sql, args, nil) + if err != nil { + return 0, err + } + defer rows.Close() + + count := 0 + for rows.Next() { + if _, err := rows.Values(); err != nil { + return 0, err + } + count++ + } + if count == 0 { + return 0, pgx.ErrNoRows + } + + return count, nil +} + // Common function for all exec calls func (c Client) exec(ctx context.Context, sql string, args []interface{}, err error) error { slog.Debug("sql", slog.String("sql", sql), slog.Any("args", args))