Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 56 additions & 109 deletions go/apps/api/routes/v2_apis_list_keys/200_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package handler_test
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"net/http"
"testing"
Expand All @@ -15,9 +14,8 @@ import (
"github.com/unkeyed/unkey/go/pkg/hash"
"github.com/unkeyed/unkey/go/pkg/ptr"
"github.com/unkeyed/unkey/go/pkg/testutil"
"github.com/unkeyed/unkey/go/pkg/testutil/seed"
"github.com/unkeyed/unkey/go/pkg/uid"

vaultv1 "github.com/unkeyed/unkey/go/gen/proto/vault/v1"
)

func TestSuccess(t *testing.T) {
Expand Down Expand Up @@ -70,127 +68,75 @@ func TestSuccess(t *testing.T) {
require.NoError(t, err)

// Create test identities
identity1ID := uid.New("identity")
identity1ExternalID := "test_user_1"
err = db.Query.InsertIdentity(ctx, h.DB.RW(), db.InsertIdentityParams{
ID: identity1ID,
ExternalID: identity1ExternalID,
identity1 := h.CreateIdentity(seed.CreateIdentityRequest{
WorkspaceID: workspace.ID,
Environment: "",
CreatedAt: time.Now().UnixMilli(),
ExternalID: identity1ExternalID,
Meta: []byte(`{"role": "admin"}`),
})
require.NoError(t, err)

identity2ID := uid.New("identity")
identity2ExternalID := "test_user_2"
err = db.Query.InsertIdentity(ctx, h.DB.RW(), db.InsertIdentityParams{
ID: identity2ID,
ExternalID: identity2ExternalID,
identity2 := h.CreateIdentity(seed.CreateIdentityRequest{
WorkspaceID: workspace.ID,
Environment: "",
CreatedAt: time.Now().UnixMilli(),
ExternalID: identity2ExternalID,
Meta: []byte(`{"role": "user"}`),
})
require.NoError(t, err)

encryptedKeysMap := make(map[string]struct{})
// Create test keys with various configurations
testKeys := []struct {
id string
start string
name string
identityID *string
meta map[string]interface{}
expires *time.Time
enabled bool
}{
{
id: uid.New("key"),
start: "test_key1_",
name: "Test Key 1",
identityID: &identity1ID,
meta: map[string]interface{}{"env": "production", "team": "backend"},
enabled: true,
},
{
id: uid.New("key"),
start: "test_key2_",
name: "Test Key 2",
identityID: &identity1ID,
meta: map[string]interface{}{"env": "staging"},
enabled: true,
},
{
id: uid.New("key"),
start: "test_key3_",
name: "Test Key 3",
identityID: &identity2ID,
meta: map[string]interface{}{"env": "development"},
enabled: true,
},
{
id: uid.New("key"),
start: "test_key4_",
name: "Test Key 4 (No Identity)",
enabled: true,
},
{
id: uid.New("key"),
start: "test_key5_",
name: "Test Key 5 (Disabled)",
enabled: false,
},
}

for _, keyData := range testKeys {
metaBytes := []byte("{}")
if keyData.meta != nil {
metaBytes, _ = json.Marshal(keyData.meta)
}

key := keyData.start + uid.New("")
// Track encrypted keys for verification
encryptedKeys := make(map[string]string) // keyID -> plaintext

insertParams := db.InsertKeyParams{
ID: keyData.id,
KeySpaceID: keySpaceID,
Hash: hash.Sha256(key),
Start: keyData.start,
WorkspaceID: workspace.ID,
ForWorkspaceID: sql.NullString{Valid: false},
Name: sql.NullString{Valid: true, String: keyData.name},
Meta: sql.NullString{Valid: true, String: string(metaBytes)},
Expires: sql.NullTime{Valid: false},
CreatedAtM: time.Now().UnixMilli(),
Enabled: keyData.enabled,
RemainingRequests: sql.NullInt32{Valid: false},
}
// Key 1: identity1, production metadata
key1 := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: keySpaceID,
Name: ptr.P("Test Key 1"),
IdentityID: ptr.P(identity1.ID),
Meta: ptr.P(`{"env": "production", "team": "backend"}`),
Recoverable: true,
})
encryptedKeys[key1.KeyID] = key1.Key

if keyData.identityID != nil {
insertParams.IdentityID = sql.NullString{Valid: true, String: *keyData.identityID}
} else {
insertParams.IdentityID = sql.NullString{Valid: false}
}
// Key 2: identity1, staging metadata
key2 := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: keySpaceID,
Name: ptr.P("Test Key 2"),
IdentityID: ptr.P(identity1.ID),
Meta: ptr.P(`{"env": "staging"}`),
Recoverable: true,
})
encryptedKeys[key2.KeyID] = key2.Key

err := db.Query.InsertKey(ctx, h.DB.RW(), insertParams)
require.NoError(t, err)
// Key 3: identity2, development metadata
key3 := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: keySpaceID,
Name: ptr.P("Test Key 3"),
IdentityID: ptr.P(identity2.ID),
Meta: ptr.P(`{"env": "development"}`),
Recoverable: true,
})
encryptedKeys[key3.KeyID] = key3.Key

encryption, err := h.Vault.Encrypt(ctx, &vaultv1.EncryptRequest{
Keyring: h.Resources().UserWorkspace.ID,
Data: key,
})
require.NoError(t, err)
// Key 4: no identity
key4 := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: keySpaceID,
Name: ptr.P("Test Key 4 (No Identity)"),
Recoverable: true,
})
encryptedKeys[key4.KeyID] = key4.Key

err = db.Query.InsertKeyEncryption(ctx, h.DB.RW(), db.InsertKeyEncryptionParams{
WorkspaceID: h.Resources().UserWorkspace.ID,
KeyID: keyData.id,
CreatedAt: time.Now().UnixMilli(),
Encrypted: encryption.GetEncrypted(),
EncryptionKeyID: encryption.GetKeyId(),
})
require.NoError(t, err)
encryptedKeysMap[keyData.id] = struct{}{}
}
// Key 5: disabled
key5 := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: keySpaceID,
Name: ptr.P("Test Key 5 (Disabled)"),
Disabled: true,
Recoverable: true,
})
encryptedKeys[key5.KeyID] = key5.Key

// Set up request headers
headers := http.Header{
Expand Down Expand Up @@ -604,12 +550,13 @@ func TestSuccess(t *testing.T) {
require.NotNil(t, res.Body.Data)

for _, key := range res.Body.Data {
_, exists := encryptedKeysMap[key.KeyId]
expectedPlaintext, exists := encryptedKeys[key.KeyId]
if !exists {
continue
}

require.NotEmpty(t, key.Plaintext, "Key should be decrypted and have plaintext")
require.Equal(t, expectedPlaintext, key.Plaintext, "Key should be decrypted and have correct plaintext")
}
})
}
33 changes: 30 additions & 3 deletions go/apps/api/routes/v2_apis_list_keys/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package handler

import (
"context"
"database/sql"
"net/http"

"github.com/oapi-codegen/nullable"
Expand Down Expand Up @@ -161,9 +162,35 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
limit := ptr.SafeDeref(req.Limit, 100)
cursor := ptr.SafeDeref(req.Cursor, "")

var identityFilter string
// Resolve identity ID if external_id filter is provided
var identityID sql.NullString
if req.ExternalId != nil && *req.ExternalId != "" {
identityFilter = *req.ExternalId
identity, identityErr := db.Query.FindIdentityByExternalID(ctx, h.DB.RO(), db.FindIdentityByExternalIDParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
ExternalID: *req.ExternalId,
Deleted: false,
})
if identityErr != nil {
if db.IsNotFound(identityErr) {
// Identity doesn't exist, return empty result set
return s.JSON(http.StatusOK, Response{
Meta: openapi.Meta{
RequestId: s.RequestID(),
},
Data: []openapi.KeyResponseData{},
Pagination: &openapi.Pagination{
Cursor: nil,
HasMore: false,
},
})
}
return fault.Wrap(identityErr,
fault.Code(codes.App.Internal.ServiceUnavailable.URN()),
fault.Internal("database error"),
fault.Public("Failed to retrieve identity."),
)
}
identityID = sql.NullString{String: identity.ID, Valid: true}
}

// Query keys by key_auth_id instead of api_id
Expand All @@ -173,7 +200,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
db.ListLiveKeysByKeySpaceIDParams{
KeySpaceID: api.KeyAuthID.String,
IDCursor: cursor,
Identity: identityFilter,
IdentityID: identityID,
Limit: int32(limit + 1), // nolint:gosec
},
)
Expand Down
18 changes: 8 additions & 10 deletions go/apps/api/routes/v2_identities_delete_identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,27 +71,25 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
return err
}

results, err := db.Query.FindIdentityWithRatelimits(ctx, h.DB.RO(), db.FindIdentityWithRatelimitsParams{
identity, err := db.Query.FindIdentity(ctx, h.DB.RO(), db.FindIdentityParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
Identity: req.Identity,
Deleted: false,
})
if err != nil {
if db.IsNotFound(err) {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"), fault.Public("This identity does not exist."),
)
}

return fault.Wrap(err,
fault.Code(codes.App.Internal.ServiceUnavailable.URN()),
fault.Internal("database failed to find the identity"), fault.Public("Error finding the identity."),
)
}

if len(results) == 0 {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"), fault.Public("This identity does not exist."),
)
}

identity := results[0]

// Parse ratelimits JSON
var ratelimits []db.RatelimitInfo
if ratelimitBytes, ok := identity.Ratelimits.([]byte); ok && ratelimitBytes != nil {
Expand Down
19 changes: 8 additions & 11 deletions go/apps/api/routes/v2_identities_get_identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,28 +50,25 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
return err
}

results, err := db.Query.FindIdentityWithRatelimits(ctx, h.DB.RO(), db.FindIdentityWithRatelimitsParams{
identity, err := db.Query.FindIdentity(ctx, h.DB.RO(), db.FindIdentityParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
Identity: req.Identity,
Deleted: false,
})
if err != nil {
if db.IsNotFound(err) {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"), fault.Public("This identity does not exist."),
)
}

return fault.Wrap(err,
fault.Internal("unable to find identity"),
fault.Public("We're unable to retrieve the identity."),
)
}

if len(results) == 0 {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"),
fault.Public("This identity does not exist."),
)
}

identity := results[0]

// Parse ratelimits JSON
ratelimits, err := db.UnmarshalNullableJSONTo[[]db.RatelimitInfo](identity.Ratelimits)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ func TestNotFound(t *testing.T) {
res := testutil.CallRoute[handler.Request, openapi.NotFoundErrorResponse](h, route, headers, req)
require.Equal(t, http.StatusNotFound, res.Status, "expected 404, got: %d", res.Status)
require.Equal(t, "https://unkey.com/docs/errors/unkey/data/identity_not_found", res.Body.Error.Type)
require.Equal(t, "Identity not found in this workspace", res.Body.Error.Detail)
require.Equal(t, "This identity does not exist.", res.Body.Error.Detail)
require.Equal(t, http.StatusNotFound, res.Body.Error.Status)
require.Equal(t, "Not Found", res.Body.Error.Title)
require.NotEmpty(t, res.Body.Meta.RequestId)
Expand Down
22 changes: 9 additions & 13 deletions go/apps/api/routes/v2_identities_update_identity/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,37 +112,33 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}
}

// Use UNION query to find identity + ratelimits in one query (fast!)
results, err := db.Query.FindIdentityWithRatelimits(ctx, h.DB.RO(), db.FindIdentityWithRatelimitsParams{
identityRow, err := db.Query.FindIdentity(ctx, h.DB.RO(), db.FindIdentityParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
Identity: req.Identity,
Deleted: false,
})
if err != nil {
if db.IsNotFound(err) {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"), fault.Public("This identity does not exist."),
)
}

return fault.Wrap(err,
fault.Internal("unable to find identity"),
fault.Public("We're unable to retrieve the identity."),
)
}

if len(results) == 0 {
return fault.New("identity not found",
fault.Code(codes.Data.Identity.NotFound.URN()),
fault.Internal("identity not found"),
fault.Public("Identity not found in this workspace"),
)
}

identityRow := results[0]

// Parse existing ratelimits from JSON
var existingRatelimits []db.RatelimitInfo
if ratelimitBytes, ok := identityRow.Ratelimits.([]byte); ok && ratelimitBytes != nil {
_ = json.Unmarshal(ratelimitBytes, &existingRatelimits) // Ignore error, default to empty array
}

type txResult struct {
identity db.FindIdentityWithRatelimitsRow
identity db.FindIdentityRow
finalRatelimits []openapi.RatelimitResponse
}

Expand Down
Loading
Loading