Skip to content

Commit

Permalink
feat: add PKCE support
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 committed Nov 15, 2024
1 parent 760c8e8 commit b711e8a
Show file tree
Hide file tree
Showing 11 changed files with 79 additions and 30 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Additionally, what makes Pocket ID special is that it only supports [passkey](ht
## Setup

> [!WARNING]
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue. For example PKCE is not yet implemented.
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue.
### Before you start

Expand Down
8 changes: 8 additions & 0 deletions backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,11 @@ func (e *AccountEditNotAllowedError) Error() string {
return "You are not allowed to edit your account"
}
func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden }

type OidcInvalidCodeVerifierError struct{}

func (e *OidcInvalidCodeVerifierError) Error() string {
return "Invalid code verifier"
}

func (e *OidcInvalidCodeVerifierError) HttpStatusCode() int { return http.StatusBadRequest }
4 changes: 2 additions & 2 deletions backend/internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
}

func (oc *OidcController) createTokensHandler(c *gin.Context) {
var input dto.OidcIdTokenDto
var input dto.OidcCreateTokensDto

if err := c.ShouldBind(&input); err != nil {
c.Error(err)
Expand All @@ -100,7 +100,7 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
}
}

idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier)
if err != nil {
c.Error(err)
return
Expand Down
13 changes: 8 additions & 5 deletions backend/internal/dto/oidc_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,23 @@ type OidcClientCreateDto struct {
}

type AuthorizeOidcClientRequestDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}

type AuthorizeOidcClientResponseDto struct {
Code string `json:"code"`
CallbackURL string `json:"callbackURL"`
}

type OidcIdTokenDto struct {
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
}
10 changes: 6 additions & 4 deletions backend/internal/model/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ type UserAuthorizedOidcClient struct {
type OidcAuthorizationCode struct {
Base

Code string
Scope string
Nonce string
ExpiresAt datatype.DateTime
Code string
Scope string
Nonce string
CodeChallenge *string
CodeChallengeMethodSha256 *bool
ExpiresAt datatype.DateTime

UserID string
User User
Expand Down
54 changes: 41 additions & 13 deletions backend/internal/service/oidc_service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package service

import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
Expand Down Expand Up @@ -43,12 +45,12 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
return "", "", &common.OidcMissingAuthorizationError{}
}

callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}

code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
Expand All @@ -64,7 +66,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return "", "", err
}

callbackURL, err := getCallbackURL(client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
Expand All @@ -83,7 +85,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
}
}

code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
Expand All @@ -93,7 +95,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return code, callbackURL, nil
}

func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}
Expand All @@ -118,6 +120,12 @@ func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret strin
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}

if authorizationCodeMetaData.CodeChallenge != nil {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
}
}

if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
Expand Down Expand Up @@ -358,19 +366,23 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}

func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}

codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256"

oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}

if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
Expand All @@ -380,7 +392,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}

func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool {
if !codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}

// Compute SHA-256 hash of the codeVerifier
h := sha256.New()
h.Write([]byte(codeVerifier))
codeVerifierHash := h.Sum(nil)

// Base64 URL encode the verifier hash
encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash)

return encodedVerifierHash == codeChallenge
}

func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}
Expand Down
2 changes: 2 additions & 0 deletions backend/migrations/20241115131129_pkce.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge;
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge_method_sha256;
2 changes: 2 additions & 0 deletions backend/migrations/20241115131129_pkce.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge TEXT;
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge_method_sha256 NUMERIC;
6 changes: 4 additions & 2 deletions frontend/src/lib/services/oidc-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ import type { Paginated, PaginationRequest } from '$lib/types/pagination.type';
import APIService from './api-service';

class OidcService extends APIService {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string) {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) {
const res = await this.api.post('/oidc/authorize', {
scope,
nonce,
callbackURL,
clientId
clientId,
codeChallenge,
codeChallengeMethod
});

return res.data as AuthorizeResponse;
Expand Down
4 changes: 3 additions & 1 deletion frontend/src/routes/authorize/+page.server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ export const load: PageServerLoad = async ({ url, cookies }) => {
nonce: url.searchParams.get('nonce') || undefined,
state: url.searchParams.get('state')!,
callbackURL: url.searchParams.get('redirect_uri')!,
client
client,
codeChallenge: url.searchParams.get('code_challenge')!,
codeChallengeMethod: url.searchParams.get('code_challenge_method')!
};
};
4 changes: 2 additions & 2 deletions frontend/src/routes/authorize/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
let authorizationRequired = false;
export let data: PageData;
let { scope, nonce, client, state, callbackURL } = data;
let { scope, nonce, client, state, callbackURL, codeChallenge, codeChallengeMethod } = data;
async function authorize() {
isLoading = true;
Expand All @@ -37,7 +37,7 @@
}
await oidService
.authorize(client!.id, scope, callbackURL, nonce)
.authorize(client!.id, scope, callbackURL, nonce, codeChallenge, codeChallengeMethod)
.then(async ({ code, callbackURL }) => {
onSuccess(code, callbackURL);
});
Expand Down

0 comments on commit b711e8a

Please sign in to comment.