Skip to content

Commit

Permalink
feat: add PKCE support
Browse files Browse the repository at this point in the history
  • Loading branch information
stonith404 committed Nov 17, 2024
1 parent 760c8e8 commit 3613ac2
Show file tree
Hide file tree
Showing 15 changed files with 188 additions and 86 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Additionally, what makes Pocket ID special is that it only supports [passkey](ht
## Setup

> [!WARNING]
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue. For example PKCE is not yet implemented.
> Pocket ID is in its early stages and may contain bugs. There might be OIDC features that are not yet implemented. If you encounter any issues, please open an issue.
### Before you start

Expand Down
16 changes: 15 additions & 1 deletion backend/internal/common/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (e *TooManyRequestsError) HttpStatusCode() int { return http.StatusTooManyR
type ClientIdOrSecretNotProvidedError struct{}

func (e *ClientIdOrSecretNotProvidedError) Error() string {
return "Client id and secret not provided"
return "Client id or secret not provided"
}
func (e *ClientIdOrSecretNotProvidedError) HttpStatusCode() int { return http.StatusBadRequest }

Expand Down Expand Up @@ -146,3 +146,17 @@ func (e *AccountEditNotAllowedError) Error() string {
return "You are not allowed to edit your account"
}
func (e *AccountEditNotAllowedError) HttpStatusCode() int { return http.StatusForbidden }

type OidcInvalidCodeVerifierError struct{}

func (e *OidcInvalidCodeVerifierError) Error() string {
return "Invalid code verifier"
}
func (e *OidcInvalidCodeVerifierError) HttpStatusCode() int { return http.StatusBadRequest }

type OidcMissingCodeChallengeError struct{}

func (e *OidcMissingCodeChallengeError) Error() string {
return "Missing code challenge"
}
func (e *OidcMissingCodeChallengeError) HttpStatusCode() int { return http.StatusBadRequest }
17 changes: 7 additions & 10 deletions backend/internal/controller/oidc_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package controller

import (
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
"github.com/stonith404/pocket-id/backend/internal/dto"
"github.com/stonith404/pocket-id/backend/internal/middleware"
"github.com/stonith404/pocket-id/backend/internal/service"
Expand Down Expand Up @@ -80,7 +79,10 @@ func (oc *OidcController) authorizeNewClientHandler(c *gin.Context) {
}

func (oc *OidcController) createTokensHandler(c *gin.Context) {
var input dto.OidcIdTokenDto
// Disable cors for this endpoint
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")

var input dto.OidcCreateTokensDto

if err := c.ShouldBind(&input); err != nil {
c.Error(err)
Expand All @@ -91,16 +93,11 @@ func (oc *OidcController) createTokensHandler(c *gin.Context) {
clientSecret := input.ClientSecret

// Client id and secret can also be passed over the Authorization header
if clientID == "" || clientSecret == "" {
var ok bool
clientID, clientSecret, ok = c.Request.BasicAuth()
if !ok {
c.Error(&common.ClientIdOrSecretNotProvidedError{})
return
}
if clientID == "" && clientSecret == "" {
clientID, clientSecret, _ = c.Request.BasicAuth()
}

idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret)
idToken, accessToken, err := oc.oidcService.CreateTokens(input.Code, input.GrantType, clientID, clientSecret, input.CodeVerifier)
if err != nil {
c.Error(err)
return
Expand Down
15 changes: 10 additions & 5 deletions backend/internal/dto/oidc_dto.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,34 @@ type PublicOidcClientDto struct {
type OidcClientDto struct {
PublicOidcClientDto
CallbackURLs []string `json:"callbackURLs"`
IsPublic bool `json:"isPublic"`
CreatedBy UserDto `json:"createdBy"`
}

type OidcClientCreateDto struct {
Name string `json:"name" binding:"required,max=50"`
CallbackURLs []string `json:"callbackURLs" binding:"required,urlList"`
IsPublic bool `json:"isPublic"`
}

type AuthorizeOidcClientRequestDto struct {
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
ClientID string `json:"clientID" binding:"required"`
Scope string `json:"scope" binding:"required"`
CallbackURL string `json:"callbackURL"`
Nonce string `json:"nonce"`
CodeChallenge string `json:"codeChallenge"`
CodeChallengeMethod string `json:"codeChallengeMethod"`
}

type AuthorizeOidcClientResponseDto struct {
Code string `json:"code"`
CallbackURL string `json:"callbackURL"`
}

type OidcIdTokenDto struct {
type OidcCreateTokensDto struct {
GrantType string `form:"grant_type" binding:"required"`
Code string `form:"code" binding:"required"`
ClientID string `form:"client_id"`
ClientSecret string `form:"client_secret"`
CodeVerifier string `form:"code_verifier"`
}
29 changes: 19 additions & 10 deletions backend/internal/middleware/cors.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
package middleware

import (
"github.com/stonith404/pocket-id/backend/internal/common"
"time"

"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/stonith404/pocket-id/backend/internal/common"
)

type CorsMiddleware struct{}
Expand All @@ -15,10 +12,22 @@ func NewCorsMiddleware() *CorsMiddleware {
}

func (m *CorsMiddleware) Add() gin.HandlerFunc {
return cors.New(cors.Config{
AllowOrigins: []string{common.EnvConfig.AppURL},
AllowMethods: []string{"*"},
AllowHeaders: []string{"*"},
MaxAge: 12 * time.Hour,
})
return func(c *gin.Context) {
// Allow all origins for the token endpoint
if c.FullPath() == "/api/oidc/token" {
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
} else {
c.Writer.Header().Set("Access-Control-Allow-Origin", common.EnvConfig.AppURL)
}

c.Writer.Header().Set("Access-Control-Allow-Headers", "*")
c.Writer.Header().Set("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT")

if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(204)
return
}

c.Next()
}
}
11 changes: 7 additions & 4 deletions backend/internal/model/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ type UserAuthorizedOidcClient struct {
type OidcAuthorizationCode struct {
Base

Code string
Scope string
Nonce string
ExpiresAt datatype.DateTime
Code string
Scope string
Nonce string
CodeChallenge *string
CodeChallengeMethodSha256 *bool
ExpiresAt datatype.DateTime

UserID string
User User
Expand All @@ -39,6 +41,7 @@ type OidcClient struct {
CallbackURLs CallbackURLs
ImageType *string
HasLogo bool `gorm:"-"`
IsPublic bool

CreatedByID string
CreatedBy User
Expand Down
83 changes: 62 additions & 21 deletions backend/internal/service/oidc_service.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package service

import (
"crypto/sha256"
"encoding/base64"
"errors"
"fmt"
"github.com/stonith404/pocket-id/backend/internal/common"
Expand Down Expand Up @@ -39,16 +41,20 @@ func (s *OidcService) Authorize(input dto.AuthorizeOidcClientRequestDto, userID,
var userAuthorizedOIDCClient model.UserAuthorizedOidcClient
s.db.Preload("Client").First(&userAuthorizedOIDCClient, "client_id = ? AND user_id = ?", input.ClientID, userID)

if userAuthorizedOIDCClient.Client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}

if userAuthorizedOIDCClient.Scope != input.Scope {
return "", "", &common.OidcMissingAuthorizationError{}
}

callbackURL, err := getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
callbackURL, err := s.getCallbackURL(userAuthorizedOIDCClient.Client, input.CallbackURL)
if err != nil {
return "", "", err
}

code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
Expand All @@ -64,7 +70,11 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return "", "", err
}

callbackURL, err := getCallbackURL(client, input.CallbackURL)
if client.IsPublic && input.CodeChallenge == "" {
return "", "", &common.OidcMissingCodeChallengeError{}
}

callbackURL, err := s.getCallbackURL(client, input.CallbackURL)
if err != nil {
return "", "", err
}
Expand All @@ -83,7 +93,7 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
}
}

code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce)
code, err := s.createAuthorizationCode(input.ClientID, userID, input.Scope, input.Nonce, input.CodeChallenge, input.CodeChallengeMethod)
if err != nil {
return "", "", err
}
Expand All @@ -93,31 +103,41 @@ func (s *OidcService) AuthorizeNewClient(input dto.AuthorizeOidcClientRequestDto
return code, callbackURL, nil
}

func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret string) (string, string, error) {
func (s *OidcService) CreateTokens(code, grantType, clientID, clientSecret, codeVerifier string) (string, string, error) {
if grantType != "authorization_code" {
return "", "", &common.OidcGrantTypeNotSupportedError{}
}

if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}

var client model.OidcClient
if err := s.db.First(&client, "id = ?", clientID).Error; err != nil {
return "", "", err
}

err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
// Verify the client secret if the client is not public
if !client.IsPublic {
if clientID == "" || clientSecret == "" {
return "", "", &common.OidcMissingClientCredentialsError{}
}

err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(clientSecret))
if err != nil {
return "", "", &common.OidcClientSecretInvalidError{}
}
}

var authorizationCodeMetaData model.OidcAuthorizationCode
err = s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
err := s.db.Preload("User").First(&authorizationCodeMetaData, "code = ?", code).Error
if err != nil {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}

// If the client is public, the code verifier must match the code challenge
if client.IsPublic {
if !s.validateCodeVerifier(codeVerifier, *authorizationCodeMetaData.CodeChallenge, *authorizationCodeMetaData.CodeChallengeMethodSha256) {
return "", "", &common.OidcInvalidCodeVerifierError{}
}
}

if authorizationCodeMetaData.ClientID != clientID && authorizationCodeMetaData.ExpiresAt.ToTime().Before(time.Now()) {
return "", "", &common.OidcInvalidAuthorizationCodeError{}
}
Expand Down Expand Up @@ -186,6 +206,7 @@ func (s *OidcService) UpdateClient(clientID string, input dto.OidcClientCreateDt

client.Name = input.Name
client.CallbackURLs = input.CallbackURLs
client.IsPublic = input.IsPublic

if err := s.db.Save(&client).Error; err != nil {
return model.OidcClient{}, err
Expand Down Expand Up @@ -358,19 +379,23 @@ func (s *OidcService) GetUserClaimsForClient(userID string, clientID string) (ma
return claims, nil
}

func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string) (string, error) {
func (s *OidcService) createAuthorizationCode(clientID string, userID string, scope string, nonce string, codeChallenge string, codeChallengeMethod string) (string, error) {
randomString, err := utils.GenerateRandomAlphanumericString(32)
if err != nil {
return "", err
}

codeChallengeMethodSha256 := strings.ToUpper(codeChallengeMethod) == "S256"

oidcAuthorizationCode := model.OidcAuthorizationCode{
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
ExpiresAt: datatype.DateTime(time.Now().Add(15 * time.Minute)),
Code: randomString,
ClientID: clientID,
UserID: userID,
Scope: scope,
Nonce: nonce,
CodeChallenge: &codeChallenge,
CodeChallengeMethodSha256: &codeChallengeMethodSha256,
}

if err := s.db.Create(&oidcAuthorizationCode).Error; err != nil {
Expand All @@ -380,7 +405,23 @@ func (s *OidcService) createAuthorizationCode(clientID string, userID string, sc
return randomString, nil
}

func getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
func (s *OidcService) validateCodeVerifier(codeVerifier, codeChallenge string, codeChallengeMethodSha256 bool) bool {
if !codeChallengeMethodSha256 {
return codeVerifier == codeChallenge
}

// Compute SHA-256 hash of the codeVerifier
h := sha256.New()
h.Write([]byte(codeVerifier))
codeVerifierHash := h.Sum(nil)

// Base64 URL encode the verifier hash
encodedVerifierHash := base64.RawURLEncoding.EncodeToString(codeVerifierHash)

return encodedVerifierHash == codeChallenge
}

func (s *OidcService) getCallbackURL(client model.OidcClient, inputCallbackURL string) (callbackURL string, err error) {
if inputCallbackURL == "" {
return client.CallbackURLs[0], nil
}
Expand Down
3 changes: 3 additions & 0 deletions backend/migrations/20241115131129_pkce.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge;
ALTER TABLE oidc_authorization_codes DROP COLUMN code_challenge_method_sha256;
ALTER TABLE oidc_clients DROP COLUMN is_public;
3 changes: 3 additions & 0 deletions backend/migrations/20241115131129_pkce.up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge TEXT;
ALTER TABLE oidc_authorization_codes ADD COLUMN code_challenge_method_sha256 NUMERIC;
ALTER TABLE oidc_clients ADD COLUMN is_public BOOLEAN DEFAULT FALSE;
12 changes: 8 additions & 4 deletions frontend/src/lib/services/oidc-service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,27 @@ import type { Paginated, PaginationRequest } from '$lib/types/pagination.type';
import APIService from './api-service';

class OidcService extends APIService {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string) {
async authorize(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) {
const res = await this.api.post('/oidc/authorize', {
scope,
nonce,
callbackURL,
clientId
clientId,
codeChallenge,
codeChallengeMethod
});

return res.data as AuthorizeResponse;
}

async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string) {
async authorizeNewClient(clientId: string, scope: string, callbackURL: string, nonce?: string, codeChallenge?: string, codeChallengeMethod?: string) {
const res = await this.api.post('/oidc/authorize/new-client', {
scope,
nonce,
callbackURL,
clientId
clientId,
codeChallenge,
codeChallengeMethod
});

return res.data as AuthorizeResponse;
Expand Down
Loading

0 comments on commit 3613ac2

Please sign in to comment.