Skip to content

Commit 21fd1d4

Browse files
Feat custom user handle (#1978)
Add a custom user handle to a webauthn credential --------- Co-authored-by: bjoern-m <[email protected]>
1 parent e172e05 commit 21fd1d4

9 files changed

+135
-57
lines changed

backend/flow_api/flow/shared/hook_issue_session.go

+3-4
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ func (h IssueSession) Execute(c flowpilot.HookExecutionContext) error {
2929
return errors.New("user_id not found in stash")
3030
}
3131

32-
emails, err := deps.Persister.GetEmailPersisterWithConnection(deps.Tx).FindByUserId(userId)
32+
userModel, err := deps.Persister.GetUserPersisterWithConnection(deps.Tx).Get(userId)
3333
if err != nil {
34-
return fmt.Errorf("failed to fetch emails from db: %w", err)
34+
return fmt.Errorf("failed to fetch user from db: %w", err)
3535
}
3636

3737
var emailDTO *dto.EmailJwt
38-
39-
if email := emails.GetPrimary(); email != nil {
38+
if email := userModel.Emails.GetPrimary(); email != nil {
4039
emailDTO = dto.JwtFromEmailModel(email)
4140
}
4241

backend/flow_api/services/webauthn.go

+51-35
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package services
22

33
import (
4-
"encoding/base64"
54
"errors"
65
"fmt"
76
"github.com/go-webauthn/webauthn/protocol"
@@ -178,45 +177,38 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
178177
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
179178
}
180179

181-
sessionDataModel, err := s.persister.GetWebauthnSessionDataPersister().Get(p.SessionDataID)
180+
sessionDataModel, err := s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Get(p.SessionDataID)
182181
if err != nil {
183182
return nil, fmt.Errorf("failed to get session data from db: %w", err)
184183
}
185184

186-
var userID uuid.UUID
187-
if p.IsMFA {
188-
userID = sessionDataModel.UserId
189-
} else {
190-
userID, err = uuid.FromBytes(credentialAssertionData.Response.UserHandle)
191-
if err != nil {
192-
return nil, fmt.Errorf("failed to parse user id from user handle: %w", err)
193-
}
194-
}
195-
196-
userModel, err := s.persister.GetUserPersister().Get(userID)
185+
credentialModel, err := s.persister.GetWebauthnCredentialPersister().Get(credentialAssertionData.ID)
197186
if err != nil {
198-
return nil, fmt.Errorf("failed to fetch user from db: %w", err)
187+
return nil, fmt.Errorf("failed to get webauthncredential from db: %w", err)
199188
}
200189

201-
if userModel == nil {
202-
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
190+
if credentialModel == nil {
191+
return nil, ErrInvalidWebauthnCredential
203192
}
204193

205-
cred := userModel.GetWebauthnCredentialById(credentialAssertionData.ID)
206-
if cred != nil && (!p.IsMFA && cred.MFAOnly) {
194+
if !p.IsMFA && credentialModel.MFAOnly {
207195
return nil, ErrInvalidWebauthnCredentialMFAOnly
208196
}
209197

198+
webAuthnUser, userModel, err := s.GetWebAuthnUser(p.Tx, *credentialModel)
199+
if err != nil {
200+
return nil, err
201+
}
202+
210203
discoverableUserHandler := func(rawID, userHandle []byte) (webauthn.User, error) {
211-
return userModel, nil
204+
return webAuthnUser, nil
212205
}
213206

214207
sessionData := sessionDataModel.ToSessionData()
215-
var credential *webauthn.Credential
216208
if p.IsMFA {
217-
credential, err = s.cfg.Webauthn.Handler.ValidateLogin(userModel, *sessionData, credentialAssertionData)
209+
_, err = s.cfg.Webauthn.Handler.ValidateLogin(webAuthnUser, *sessionData, credentialAssertionData)
218210
} else {
219-
credential, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
211+
_, err = s.cfg.Webauthn.Handler.ValidateDiscoverableLogin(
220212
discoverableUserHandler,
221213
*sessionData,
222214
credentialAssertionData,
@@ -226,19 +218,16 @@ func (s *webauthnService) VerifyAssertionResponse(p VerifyAssertionResponseParam
226218
return nil, fmt.Errorf("%s: %w", err, ErrInvalidWebauthnCredential)
227219
}
228220

229-
encodedCredentialId := base64.RawURLEncoding.EncodeToString(credential.ID)
230-
if credentialModel := userModel.GetWebauthnCredentialById(encodedCredentialId); credentialModel != nil {
231-
now := time.Now().UTC()
232-
flags := credentialAssertionData.Response.AuthenticatorData.Flags
221+
now := time.Now().UTC()
222+
flags := credentialAssertionData.Response.AuthenticatorData.Flags
233223

234-
credentialModel.LastUsedAt = &now
235-
credentialModel.BackupState = flags.HasBackupState()
236-
credentialModel.BackupEligible = flags.HasBackupEligible()
224+
credentialModel.LastUsedAt = &now
225+
credentialModel.BackupState = flags.HasBackupState()
226+
credentialModel.BackupEligible = flags.HasBackupEligible()
237227

238-
err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
239-
if err != nil {
240-
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
241-
}
228+
err = s.persister.GetWebauthnCredentialPersisterWithConnection(p.Tx).Update(*credentialModel)
229+
if err != nil {
230+
return nil, fmt.Errorf("failed to update webauthn credential: %w", err)
242231
}
243232

244233
err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Delete(*sessionDataModel)
@@ -279,11 +268,10 @@ func (s *webauthnService) generateCreationOptions(p GenerateCreationOptionsParam
279268

280269
err = s.persister.GetWebauthnSessionDataPersisterWithConnection(p.Tx).Create(*sessionDataModel)
281270
if err != nil {
282-
return nil, nil, fmt.Errorf("failed to store session data to the db: %W", err)
271+
return nil, nil, fmt.Errorf("failed to store session data to the db: %w", err)
283272
}
284273

285274
return sessionDataModel, options, nil
286-
287275
}
288276

289277
func (s *webauthnService) GenerateCreationOptionsSecurityKey(p GenerateCreationOptionsParams) (*models.WebauthnSessionData, *protocol.CredentialCreation, error) {
@@ -354,3 +342,31 @@ func (s *webauthnService) VerifyAttestationResponse(p VerifyAttestationResponseP
354342

355343
return credential, nil
356344
}
345+
346+
func (s *webauthnService) GetWebAuthnUser(tx *pop.Connection, credential models.WebauthnCredential) (webauthn.User, *models.User, error) {
347+
user, err := s.persister.GetUserPersisterWithConnection(tx).Get(credential.UserId)
348+
if err != nil {
349+
return nil, nil, fmt.Errorf("failed to fetch user from db: %w", err)
350+
}
351+
if user == nil {
352+
return nil, nil, ErrInvalidWebauthnCredential
353+
}
354+
355+
if credential.UserHandle != nil {
356+
return &webauthnUserWithCustomUserHandle{
357+
CustomUserHandle: []byte(credential.UserHandle.Handle),
358+
User: *user,
359+
}, user, nil
360+
}
361+
362+
return user, user, err
363+
}
364+
365+
type webauthnUserWithCustomUserHandle struct {
366+
models.User
367+
CustomUserHandle []byte
368+
}
369+
370+
func (u *webauthnUserWithCustomUserHandle) WebAuthnID() []byte {
371+
return u.CustomUserHandle
372+
}

backend/handler/webauthn_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ var userId = "ec4ef049-5b88-4321-a173-21b0eff06a04"
332332
type sessionManager struct {
333333
}
334334

335-
func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt) (string, jwt.Token, error) {
335+
func (s sessionManager) GenerateJWT(_ uuid.UUID, _ *dto.EmailJwt, _ ...session.JWTOptions) (string, jwt.Token, error) {
336336
return userId, nil, nil
337337
}
338338

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
drop_foreign_key("webauthn_credentials", "webauthn_credential_user_handle_fkey", {"if_exists": false})
2+
drop_column("webauthn_credentials", "user_handle_id")
3+
drop_table("webauthn_credential_user_handles")
4+
5+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
create_table("webauthn_credential_user_handles") {
2+
t.Column("id", "uuid", {primary: true})
3+
t.Column("user_id", "uuid", {"null": false})
4+
t.Column("handle", "string", {"null": false, "unique": true})
5+
t.Timestamps()
6+
t.Index(["id", "user_id"], {"unique": true})
7+
t.ForeignKey("user_id", {"users": ["id"]}, {"on_delete": "cascade", "on_update": "cascade"})
8+
}
9+
10+
add_column("webauthn_credentials", "user_handle_id", "uuid", { "null": true })
11+
add_foreign_key("webauthn_credentials", "user_handle_id", {"webauthn_credential_user_handles": ["id"]}, {
12+
"on_delete": "set null",
13+
"on_update": "cascade",
14+
})
15+
16+
sql("ALTER TABLE webauthn_credentials ADD CONSTRAINT webauthn_credential_user_handle_fkey FOREIGN KEY (user_handle_id, user_id) REFERENCES webauthn_credential_user_handles(id, user_id) ON DELETE NO ACTION ON UPDATE CASCADE;")

backend/persistence/models/webauthn_credential.go

+16-14
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,22 @@ import (
1313

1414
// WebauthnCredential is used by pop to map your webauthn_credentials database table to your go code.
1515
type WebauthnCredential struct {
16-
ID string `db:"id" json:"id"`
17-
Name *string `db:"name" json:"name"`
18-
UserId uuid.UUID `db:"user_id" json:"user_id"`
19-
PublicKey string `db:"public_key" json:"public_key"`
20-
AttestationType string `db:"attestation_type" json:"attestation_type"`
21-
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
22-
SignCount int `db:"sign_count" json:"sign_count"`
23-
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
24-
CreatedAt time.Time `db:"created_at" json:"created_at"`
25-
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
26-
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
27-
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
28-
BackupState bool `db:"backup_state" json:"backup_state"`
29-
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
16+
ID string `db:"id" json:"id"`
17+
Name *string `db:"name" json:"name"`
18+
UserId uuid.UUID `db:"user_id" json:"user_id"`
19+
PublicKey string `db:"public_key" json:"public_key"`
20+
AttestationType string `db:"attestation_type" json:"attestation_type"`
21+
AAGUID uuid.UUID `db:"aaguid" json:"aaguid"`
22+
SignCount int `db:"sign_count" json:"sign_count"`
23+
LastUsedAt *time.Time `db:"last_used_at" json:"last_used_at"`
24+
CreatedAt time.Time `db:"created_at" json:"created_at"`
25+
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
26+
Transports Transports `has_many:"webauthn_credential_transports" json:"transports"`
27+
BackupEligible bool `db:"backup_eligible" json:"backup_eligible"`
28+
BackupState bool `db:"backup_state" json:"backup_state"`
29+
MFAOnly bool `db:"mfa_only" json:"mfa_only"`
30+
UserHandleID *uuid.UUID `db:"user_handle_id" json:"-"`
31+
UserHandle *WebauthnCredentialUserHandle `belongs_to:"webauthn_credential_user_handle" fk_id:"webauthn_credential_user_handle_fkey" json:"user_handle,omitempty"`
3032
}
3133

3234
type WebauthnCredentials []WebauthnCredential
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package models
2+
3+
import (
4+
"github.com/gobuffalo/pop/v6"
5+
"github.com/gobuffalo/validate/v3"
6+
"github.com/gobuffalo/validate/v3/validators"
7+
"github.com/gofrs/uuid"
8+
"time"
9+
)
10+
11+
type WebauthnCredentialUserHandle struct {
12+
ID uuid.UUID `db:"id" json:"id"`
13+
UserID uuid.UUID `db:"user_id" json:"user_id"`
14+
Handle string `db:"handle" json:"handle"`
15+
CreatedAt time.Time `db:"created_at" json:"created_at"`
16+
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
17+
}
18+
19+
// Validate gets run every time you call a "pop.Validate*" (pop.ValidateAndSave, pop.ValidateAndCreate, pop.ValidateAndUpdate) method.
20+
func (userHandle *WebauthnCredentialUserHandle) Validate(tx *pop.Connection) (*validate.Errors, error) {
21+
return validate.Validate(
22+
&validators.UUIDIsPresent{Name: "ID", Field: userHandle.ID},
23+
&validators.UUIDIsPresent{Name: "UserId", Field: userHandle.UserID},
24+
&validators.StringIsPresent{Name: "handle", Field: userHandle.Handle},
25+
&validators.TimeIsPresent{Name: "CreatedAt", Field: userHandle.CreatedAt},
26+
&validators.TimeIsPresent{Name: "UpdatedAt", Field: userHandle.UpdatedAt},
27+
), nil
28+
}

backend/persistence/webauthn_credential_persister.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func NewWebauthnCredentialPersister(db *pop.Connection) WebauthnCredentialPersis
2727

2828
func (p *webauthnCredentialPersister) Get(id string) (*models.WebauthnCredential, error) {
2929
credential := models.WebauthnCredential{}
30-
err := p.db.Find(&credential, id)
30+
err := p.db.Eager().Find(&credential, id)
3131
if err != nil && errors.Is(err, sql.ErrNoRows) {
3232
return nil, nil
3333
}

backend/session/session.go

+14-2
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
)
1414

1515
type Manager interface {
16-
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt) (string, jwt.Token, error)
16+
GenerateJWT(userId uuid.UUID, userDto *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error)
1717
Verify(string) (jwt.Token, error)
1818
GenerateCookie(token string) (*http.Cookie, error)
1919
DeleteCookie() (*http.Cookie, error)
@@ -90,7 +90,7 @@ func NewManager(jwkManager hankoJwk.Manager, config config.Config) (Manager, err
9090
}
9191

9292
// GenerateJWT creates a new session JWT for the given user
93-
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jwt.Token, error) {
93+
func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt, opts ...JWTOptions) (string, jwt.Token, error) {
9494
sessionID, err := uuid.NewV4()
9595
if err != nil {
9696
return "", nil, err
@@ -109,6 +109,10 @@ func (m *manager) GenerateJWT(userId uuid.UUID, email *dto.EmailJwt) (string, jw
109109
_ = token.Set("email", &email)
110110
}
111111

112+
for _, opt := range opts {
113+
opt(token)
114+
}
115+
112116
if m.issuer != "" {
113117
_ = token.Set(jwt.IssuerKey, m.issuer)
114118
}
@@ -158,3 +162,11 @@ func (m *manager) DeleteCookie() (*http.Cookie, error) {
158162
MaxAge: -1,
159163
}, nil
160164
}
165+
166+
type JWTOptions func(token jwt.Token)
167+
168+
func WithValue(key string, value interface{}) JWTOptions {
169+
return func(jwt jwt.Token) {
170+
_ = jwt.Set(key, value)
171+
}
172+
}

0 commit comments

Comments
 (0)