diff --git a/backend/.mockery.private.yml b/backend/.mockery.private.yml index c61bf1a5..38daa894 100644 --- a/backend/.mockery.private.yml +++ b/backend/.mockery.private.yml @@ -28,3 +28,19 @@ packages: structname: '{{.InterfaceName}}Mock' pkgname: cache filename: "{{.InterfaceName}}_mock_test.go" + + github.com/asgardeo/thunder/internal/oauth/oauth2/introspect: + config: + all: true + dir: internal/oauth/oauth2/introspect + structname: '{{.InterfaceName}}Mock' + pkgname: introspect + filename: "{{.InterfaceName}}_mock_test.go" + + github.com/asgardeo/thunder/internal/oauth/oauth2/authz: + config: + all: true + dir: internal/oauth/oauth2/authz + structname: '{{.InterfaceName}}Mock' + pkgname: authz + filename: "{{.InterfaceName}}_mock_test.go" diff --git a/backend/.mockery.public.yml b/backend/.mockery.public.yml index fbdfd19c..18aa2604 100644 --- a/backend/.mockery.public.yml +++ b/backend/.mockery.public.yml @@ -93,28 +93,12 @@ packages: pkgname: jwksmock filename: "{{.InterfaceName}}_mock.go" - github.com/asgardeo/thunder/internal/oauth/scope/provider: + github.com/asgardeo/thunder/internal/oauth/scope: config: all: true - dir: tests/mocks/oauth/scope/providermock + dir: tests/mocks/oauth/scopemock structname: '{{.InterfaceName}}Mock' - pkgname: providermock - filename: "{{.InterfaceName}}_mock.go" - - github.com/asgardeo/thunder/internal/oauth/scope/validator: - config: - all: true - dir: tests/mocks/oauth/scope/validatormock - structname: '{{.InterfaceName}}Mock' - pkgname: validatormock - filename: "{{.InterfaceName}}_mock.go" - - github.com/asgardeo/thunder/internal/oauth/session/store: - config: - all: true - dir: tests/mocks/oauth/session/storemock - structname: '{{.InterfaceName}}Mock' - pkgname: storemock + pkgname: scopemock filename: "{{.InterfaceName}}_mock.go" github.com/asgardeo/thunder/internal/oauth/oauth2/authz: @@ -125,14 +109,6 @@ packages: pkgname: authzmock filename: "{{.InterfaceName}}_mock.go" - github.com/asgardeo/thunder/internal/oauth/oauth2/authz/store: - config: - all: true - dir: tests/mocks/oauth/oauth2/authz/storemock - structname: '{{.InterfaceName}}Mock' - pkgname: storemock - filename: "{{.InterfaceName}}_mock.go" - github.com/asgardeo/thunder/internal/oauth/oauth2/granthandlers: config: all: true @@ -244,3 +220,11 @@ packages: structname: '{{.InterfaceName}}Mock' pkgname: httpmock filename: "{{.InterfaceName}}_mock.go" + + github.com/asgardeo/thunder/internal/application: + config: + all: true + dir: tests/mocks/applicationmock + structname: '{{.InterfaceName}}Mock' + pkgname: applicationmock + filename: "{{.InterfaceName}}_mock.go" diff --git a/backend/cmd/server/servicemanager.go b/backend/cmd/server/servicemanager.go index 8cd9fcdd..806875d7 100644 --- a/backend/cmd/server/servicemanager.go +++ b/backend/cmd/server/servicemanager.go @@ -29,6 +29,7 @@ import ( "github.com/asgardeo/thunder/internal/group" "github.com/asgardeo/thunder/internal/idp" "github.com/asgardeo/thunder/internal/notification" + "github.com/asgardeo/thunder/internal/oauth" "github.com/asgardeo/thunder/internal/ou" "github.com/asgardeo/thunder/internal/system/jwt" "github.com/asgardeo/thunder/internal/system/log" @@ -63,24 +64,15 @@ func registerServices(mux *http.ServeMux) { _ = flowexec.Initialize(mux, flowMgtService, applicationService) + // Initialize OAuth services. + oauth.Initialize(mux, applicationService, userService, jwtService) + // TODO: Legacy way of initializing services. These need to be refactored in the future aligning to the // dependency injection pattern used above. // Register the health service. services.NewHealthCheckService(mux) - // Register the token service. - services.NewTokenService(mux) - - // Register the authorization service. - services.NewAuthorizationService(mux) - - // Register the JWKS service. - services.NewJWKSAPIService(mux) - - // Register the introspection service. - services.NewIntrospectionAPIService(mux) - // Register the authentication service. services.NewAuthenticationService(mux) } diff --git a/backend/internal/oauth/init.go b/backend/internal/oauth/init.go new file mode 100644 index 00000000..0befde4f --- /dev/null +++ b/backend/internal/oauth/init.go @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Package oauth provides centralized initialization for all OAuth-related services. +package oauth + +import ( + "net/http" + + "github.com/asgardeo/thunder/internal/application" + "github.com/asgardeo/thunder/internal/oauth/jwks" + "github.com/asgardeo/thunder/internal/oauth/oauth2/granthandlers" + "github.com/asgardeo/thunder/internal/oauth/oauth2/introspect" + "github.com/asgardeo/thunder/internal/oauth/oauth2/token" + "github.com/asgardeo/thunder/internal/oauth/scope" + "github.com/asgardeo/thunder/internal/system/jwt" + "github.com/asgardeo/thunder/internal/user" +) + +// Initialize initializes all OAuth-related services and registers their routes. +func Initialize( + mux *http.ServeMux, + applicationService application.ApplicationServiceInterface, + userService user.UserServiceInterface, + jwtService jwt.JWTServiceInterface, +) { + jwks.Initialize(mux) + grantHandlerProvider := granthandlers.Initialize(mux, jwtService, userService, applicationService) + scopeValidator := scope.Initialize() + token.Initialize(mux, applicationService, grantHandlerProvider, scopeValidator) + introspect.Initialize(mux, jwtService) +} diff --git a/backend/internal/oauth/jwks/constants/constants.go b/backend/internal/oauth/jwks/constants.go similarity index 96% rename from backend/internal/oauth/jwks/constants/constants.go rename to backend/internal/oauth/jwks/constants.go index d315c38d..f455bf60 100644 --- a/backend/internal/oauth/jwks/constants/constants.go +++ b/backend/internal/oauth/jwks/constants.go @@ -16,8 +16,7 @@ * under the License. */ -// Package constants defines the constants used in the JWKS service. -package constants +package jwks import "github.com/asgardeo/thunder/internal/system/error/serviceerror" diff --git a/backend/internal/oauth/jwks/handler/handler.go b/backend/internal/oauth/jwks/handler.go similarity index 80% rename from backend/internal/oauth/jwks/handler/handler.go rename to backend/internal/oauth/jwks/handler.go index 3a9ee649..73c98f38 100644 --- a/backend/internal/oauth/jwks/handler/handler.go +++ b/backend/internal/oauth/jwks/handler.go @@ -16,34 +16,32 @@ * under the License. */ -// Package handler provides the HTTP handler for retrieving JSON Web Key Sets (JWKS). -package handler +package jwks import ( "encoding/json" "net/http" - "github.com/asgardeo/thunder/internal/oauth/jwks" serverconst "github.com/asgardeo/thunder/internal/system/constants" "github.com/asgardeo/thunder/internal/system/error/apierror" "github.com/asgardeo/thunder/internal/system/error/serviceerror" "github.com/asgardeo/thunder/internal/system/log" ) -// JWKSHandler handles requests for the JSON Web Key Set (JWKS). -type JWKSHandler struct { - jwksService jwks.JWKSServiceInterface +// jwksHandler handles requests for the JSON Web Key Set (JWKS). +type jwksHandler struct { + jwksService JWKSServiceInterface } -// NewJWKSHandler creates a new instance of JWKSHandler. -func NewJWKSHandler() *JWKSHandler { - return &JWKSHandler{ - jwksService: jwks.NewJWKSService(), +// newJWKSHandler creates a new instance of jwksHandler. +func newJWKSHandler(jwksService JWKSServiceInterface) *jwksHandler { + return &jwksHandler{ + jwksService: jwksService, } } // HandleJWKSRequest handles the HTTP request to retrieve the JSON Web Key Set (JWKS). -func (h *JWKSHandler) HandleJWKSRequest(w http.ResponseWriter, r *http.Request) { +func (h *jwksHandler) HandleJWKSRequest(w http.ResponseWriter, r *http.Request) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "JWKSHandler")) jwksResponse, svcErr := h.jwksService.GetJWKS() @@ -64,7 +62,7 @@ func (h *JWKSHandler) HandleJWKSRequest(w http.ResponseWriter, r *http.Request) } // handleError handles errors by writing an appropriate error response to the HTTP response writer. -func (h *JWKSHandler) handleError(w http.ResponseWriter, logger *log.Logger, +func (h *jwksHandler) handleError(w http.ResponseWriter, logger *log.Logger, svcErr *serviceerror.ServiceError) { w.Header().Set(serverconst.ContentTypeHeaderName, serverconst.ContentTypeJSON) diff --git a/backend/internal/system/services/jwksservice.go b/backend/internal/oauth/jwks/init.go similarity index 63% rename from backend/internal/system/services/jwksservice.go rename to backend/internal/oauth/jwks/init.go index 21954dd3..ff8dc654 100644 --- a/backend/internal/system/services/jwksservice.go +++ b/backend/internal/oauth/jwks/init.go @@ -16,39 +16,32 @@ * under the License. */ -package services +package jwks import ( "net/http" - "github.com/asgardeo/thunder/internal/oauth/jwks/handler" "github.com/asgardeo/thunder/internal/system/middleware" ) -// JWKSAPIService defines the API service for handling JWKS requests. -type JWKSAPIService struct { - jwksHandler *handler.JWKSHandler +// Initialize initializes the JWKS service and registers its routes. +func Initialize(mux *http.ServeMux) JWKSServiceInterface { + // Initialize the JWKS service + jwksService := newJWKSService() + jwksHandler := newJWKSHandler(jwksService) + registerRoutes(mux, jwksHandler) + return jwksService } -// NewJWKSAPIService creates a new instance of JWKSAPIService. -func NewJWKSAPIService(mux *http.ServeMux) ServiceInterface { - instance := &JWKSAPIService{ - jwksHandler: handler.NewJWKSHandler(), - } - instance.RegisterRoutes(mux) - - return instance -} - -// RegisterRoutes registers the routes for the JWKSAPIService. -func (s *JWKSAPIService) RegisterRoutes(mux *http.ServeMux) { +// registerRoutes registers the routes for the JWKSAPIService. +func registerRoutes(mux *http.ServeMux, jwksHandler *jwksHandler) { opts := middleware.CORSOptions{ AllowedMethods: "GET, OPTIONS", AllowedHeaders: "Content-Type, Authorization", AllowCredentials: true, } mux.HandleFunc(middleware.WithCORS("GET /oauth2/jwks", - s.jwksHandler.HandleJWKSRequest, opts)) + jwksHandler.HandleJWKSRequest, opts)) mux.HandleFunc(middleware.WithCORS("OPTIONS /oauth2/jwks", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) diff --git a/backend/internal/oauth/jwks/model/model.go b/backend/internal/oauth/jwks/model.go similarity index 94% rename from backend/internal/oauth/jwks/model/model.go rename to backend/internal/oauth/jwks/model.go index 14379278..5148272f 100644 --- a/backend/internal/oauth/jwks/model/model.go +++ b/backend/internal/oauth/jwks/model.go @@ -16,8 +16,7 @@ * under the License. */ -// Package model defines the data structures used in the JWKS service. -package model +package jwks // JWKS defines the structure of a JSON Web Key Set. type JWKS struct { diff --git a/backend/internal/oauth/jwks/service.go b/backend/internal/oauth/jwks/service.go index 8ffc6197..bcd47826 100644 --- a/backend/internal/oauth/jwks/service.go +++ b/backend/internal/oauth/jwks/service.go @@ -29,9 +29,6 @@ import ( // Use crypto/sha1 only for JWKS x5t as required by spec for thumbprint. "crypto/sha1" //nolint:gosec - "github.com/asgardeo/thunder/internal/cert" - "github.com/asgardeo/thunder/internal/oauth/jwks/constants" - "github.com/asgardeo/thunder/internal/oauth/jwks/model" "github.com/asgardeo/thunder/internal/system/config" "github.com/asgardeo/thunder/internal/system/crypto/hash" "github.com/asgardeo/thunder/internal/system/error/serviceerror" @@ -39,42 +36,38 @@ import ( // JWKSServiceInterface defines the interface for JWKS service. type JWKSServiceInterface interface { - GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError) + GetJWKS() (*JWKSResponse, *serviceerror.ServiceError) } -// JWKSService implements the JWKSServiceInterface. -type JWKSService struct { - SystemCertService cert.SystemCertificateServiceInterface -} +// jwksService implements the JWKSServiceInterface. +type jwksService struct{} -// NewJWKSService creates a new instance of JWKSService. -func NewJWKSService() JWKSServiceInterface { - return &JWKSService{ - SystemCertService: cert.NewSystemCertificateService(), - } +// newJWKSService creates a new instance of JWKSService. +func newJWKSService() JWKSServiceInterface { + return &jwksService{} } // GetJWKS retrieves the JSON Web Key Set (JWKS) from the server's TLS certificate. -func (s *JWKSService) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError) { +func (s *jwksService) GetJWKS() (*JWKSResponse, *serviceerror.ServiceError) { certConfig := config.GetThunderRuntime().CertConfig kid := certConfig.CertKid if kid == "" { - return nil, constants.ErrorCertificateKidNotFound + return nil, ErrorCertificateKidNotFound } tlsConfig := certConfig.TLSConfig if tlsConfig == nil { - return nil, constants.ErrorTLSConfigNotFound + return nil, ErrorTLSConfigNotFound } if len(tlsConfig.Certificates) == 0 || len(tlsConfig.Certificates[0].Certificate) == 0 { - return nil, constants.ErrorNoCertificateFound + return nil, ErrorNoCertificateFound } certData := tlsConfig.Certificates[0].Certificate[0] parsedCert, err := x509.ParseCertificate(certData) if err != nil { - svcErr := constants.ErrorWhileParsingCertificate + svcErr := ErrorWhileParsingCertificate svcErr.ErrorDescription = err.Error() return nil, svcErr } @@ -87,7 +80,7 @@ func (s *JWKSService) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError x5t := base64.StdEncoding.EncodeToString(sha1Sum[:]) x5tS256 := hash.GenerateThumbprint(parsedCert.Raw) - var jwks model.JWKS + var jwks JWKS switch pub := parsedCert.PublicKey.(type) { case *rsa.PublicKey: encodeBase64URL := func(b []byte) string { @@ -107,7 +100,7 @@ func (s *JWKSService) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError } eEnc := encodeBase64URL(eBytes) - jwks = model.JWKS{ + jwks = JWKS{ Kid: kid, Kty: "RSA", Use: "sig", @@ -135,7 +128,7 @@ func (s *JWKSService) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError alg = "ES512" } - jwks = model.JWKS{ + jwks = JWKS{ Kid: kid, Kty: "EC", Use: "sig", @@ -148,10 +141,10 @@ func (s *JWKSService) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError X5tS256: x5tS256, } default: - return nil, constants.ErrorUnsupportedPublicKeyType + return nil, ErrorUnsupportedPublicKeyType } - return &model.JWKSResponse{ - Keys: []model.JWKS{jwks}, + return &JWKSResponse{ + Keys: []JWKS{jwks}, }, nil } diff --git a/backend/internal/oauth/jwks/service_test.go b/backend/internal/oauth/jwks/service_test.go index 00a220b2..253aafca 100644 --- a/backend/internal/oauth/jwks/service_test.go +++ b/backend/internal/oauth/jwks/service_test.go @@ -33,14 +33,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/asgardeo/thunder/internal/cert" - "github.com/asgardeo/thunder/internal/oauth/jwks/constants" "github.com/asgardeo/thunder/internal/system/config" ) type JWKSServiceTestSuite struct { suite.Suite - jwksService *JWKSService + jwksService *jwksService } func TestJWKSServiceSuite(t *testing.T) { @@ -48,9 +46,7 @@ func TestJWKSServiceSuite(t *testing.T) { } func (suite *JWKSServiceTestSuite) SetupTest() { - suite.jwksService = &JWKSService{ - SystemCertService: cert.NewSystemCertificateService(), - } + suite.jwksService = &jwksService{} } func (suite *JWKSServiceTestSuite) setupRuntimeConfig(tlsConfig *tls.Config, certKid string) error { @@ -75,7 +71,7 @@ func (suite *JWKSServiceTestSuite) setupRuntimeConfig(tlsConfig *tls.Config, cer } func (suite *JWKSServiceTestSuite) TestNewJWKSService() { - service := NewJWKSService() + service := newJWKSService() assert.NotNil(suite.T(), service) assert.Implements(suite.T(), (*JWKSServiceInterface)(nil), service) } @@ -194,7 +190,7 @@ func (suite *JWKSServiceTestSuite) TestGetJWKS_NoCertificatesInTLSConfig() { result, svcErr := suite.jwksService.GetJWKS() assert.Nil(suite.T(), result) assert.NotNil(suite.T(), svcErr) - assert.Equal(suite.T(), constants.ErrorNoCertificateFound.Code, svcErr.Code) + assert.Equal(suite.T(), ErrorNoCertificateFound.Code, svcErr.Code) } func (suite *JWKSServiceTestSuite) TestGetJWKS_EmptyCertificateInTLSConfig() { @@ -214,7 +210,7 @@ func (suite *JWKSServiceTestSuite) TestGetJWKS_EmptyCertificateInTLSConfig() { result, svcErr := suite.jwksService.GetJWKS() assert.Nil(suite.T(), result) assert.NotNil(suite.T(), svcErr) - assert.Equal(suite.T(), constants.ErrorNoCertificateFound.Code, svcErr.Code) + assert.Equal(suite.T(), ErrorNoCertificateFound.Code, svcErr.Code) } func (suite *JWKSServiceTestSuite) TestGetJWKS_InvalidCertificateData() { @@ -234,7 +230,7 @@ func (suite *JWKSServiceTestSuite) TestGetJWKS_InvalidCertificateData() { result, svcErr := suite.jwksService.GetJWKS() assert.Nil(suite.T(), result) assert.NotNil(suite.T(), svcErr) - assert.Equal(suite.T(), constants.ErrorWhileParsingCertificate.Code, svcErr.Code) + assert.Equal(suite.T(), ErrorWhileParsingCertificate.Code, svcErr.Code) } func (suite *JWKSServiceTestSuite) TestGetJWKS_CertKidNotFound() { @@ -273,7 +269,7 @@ func (suite *JWKSServiceTestSuite) TestGetJWKS_CertKidNotFound() { result, svcErr := suite.jwksService.GetJWKS() assert.Nil(suite.T(), result) assert.NotNil(suite.T(), svcErr) - assert.Equal(suite.T(), constants.ErrorCertificateKidNotFound.Code, svcErr.Code) + assert.Equal(suite.T(), ErrorCertificateKidNotFound.Code, svcErr.Code) } func (suite *JWKSServiceTestSuite) TestGetJWKS_TLSConfigNotFound() { @@ -284,9 +280,9 @@ func (suite *JWKSServiceTestSuite) TestGetJWKS_TLSConfigNotFound() { result, svcErr := suite.jwksService.GetJWKS() assert.Nil(suite.T(), result) assert.NotNil(suite.T(), svcErr) - assert.Equal(suite.T(), constants.ErrorTLSConfigNotFound.Code, svcErr.Code) + assert.Equal(suite.T(), ErrorTLSConfigNotFound.Code, svcErr.Code) } func (suite *JWKSServiceTestSuite) TestJWKSServiceInterface() { - var _ JWKSServiceInterface = &JWKSService{} + var _ JWKSServiceInterface = &jwksService{} } diff --git a/backend/internal/oauth/oauth2/authz/AuthorizationCodeStoreInterface_mock_test.go b/backend/internal/oauth/oauth2/authz/AuthorizationCodeStoreInterface_mock_test.go new file mode 100644 index 00000000..dc80c2ea --- /dev/null +++ b/backend/internal/oauth/oauth2/authz/AuthorizationCodeStoreInterface_mock_test.go @@ -0,0 +1,306 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authz + +import ( + mock "github.com/stretchr/testify/mock" +) + +// NewAuthorizationCodeStoreInterfaceMock creates a new instance of AuthorizationCodeStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthorizationCodeStoreInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthorizationCodeStoreInterfaceMock { + mock := &AuthorizationCodeStoreInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// AuthorizationCodeStoreInterfaceMock is an autogenerated mock type for the AuthorizationCodeStoreInterface type +type AuthorizationCodeStoreInterfaceMock struct { + mock.Mock +} + +type AuthorizationCodeStoreInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthorizationCodeStoreInterfaceMock) EXPECT() *AuthorizationCodeStoreInterfaceMock_Expecter { + return &AuthorizationCodeStoreInterfaceMock_Expecter{mock: &_m.Mock} +} + +// DeactivateAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock +func (_mock *AuthorizationCodeStoreInterfaceMock) DeactivateAuthorizationCode(authzCode AuthorizationCode) error { + ret := _mock.Called(authzCode) + + if len(ret) == 0 { + panic("no return value specified for DeactivateAuthorizationCode") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(AuthorizationCode) error); ok { + r0 = returnFunc(authzCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeactivateAuthorizationCode' +type AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call struct { + *mock.Call +} + +// DeactivateAuthorizationCode is a helper method to define mock.On call +// - authzCode AuthorizationCode +func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) DeactivateAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { + return &AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call{Call: _e.mock.On("DeactivateAuthorizationCode", authzCode)} +} + +func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) Run(run func(authzCode AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 AuthorizationCode + if args[0] != nil { + arg0 = args[0].(AuthorizationCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) Return(err error) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { + _c.Call.Return(err) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) RunAndReturn(run func(authzCode AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { + _c.Call.Return(run) + return _c +} + +// ExpireAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock +func (_mock *AuthorizationCodeStoreInterfaceMock) ExpireAuthorizationCode(authzCode AuthorizationCode) error { + ret := _mock.Called(authzCode) + + if len(ret) == 0 { + panic("no return value specified for ExpireAuthorizationCode") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(AuthorizationCode) error); ok { + r0 = returnFunc(authzCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ExpireAuthorizationCode' +type AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call struct { + *mock.Call +} + +// ExpireAuthorizationCode is a helper method to define mock.On call +// - authzCode AuthorizationCode +func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) ExpireAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { + return &AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call{Call: _e.mock.On("ExpireAuthorizationCode", authzCode)} +} + +func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) Run(run func(authzCode AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 AuthorizationCode + if args[0] != nil { + arg0 = args[0].(AuthorizationCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) Return(err error) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { + _c.Call.Return(err) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) RunAndReturn(run func(authzCode AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { + _c.Call.Return(run) + return _c +} + +// GetAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock +func (_mock *AuthorizationCodeStoreInterfaceMock) GetAuthorizationCode(clientID string, authCode string) (AuthorizationCode, error) { + ret := _mock.Called(clientID, authCode) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizationCode") + } + + var r0 AuthorizationCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string) (AuthorizationCode, error)); ok { + return returnFunc(clientID, authCode) + } + if returnFunc, ok := ret.Get(0).(func(string, string) AuthorizationCode); ok { + r0 = returnFunc(clientID, authCode) + } else { + r0 = ret.Get(0).(AuthorizationCode) + } + if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { + r1 = returnFunc(clientID, authCode) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizationCode' +type AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call struct { + *mock.Call +} + +// GetAuthorizationCode is a helper method to define mock.On call +// - clientID string +// - authCode string +func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) GetAuthorizationCode(clientID interface{}, authCode interface{}) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { + return &AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call{Call: _e.mock.On("GetAuthorizationCode", clientID, authCode)} +} + +func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) Run(run func(clientID string, authCode string)) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) Return(authorizationCode AuthorizationCode, err error) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { + _c.Call.Return(authorizationCode, err) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) RunAndReturn(run func(clientID string, authCode string) (AuthorizationCode, error)) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { + _c.Call.Return(run) + return _c +} + +// InsertAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock +func (_mock *AuthorizationCodeStoreInterfaceMock) InsertAuthorizationCode(authzCode AuthorizationCode) error { + ret := _mock.Called(authzCode) + + if len(ret) == 0 { + panic("no return value specified for InsertAuthorizationCode") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(AuthorizationCode) error); ok { + r0 = returnFunc(authzCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'InsertAuthorizationCode' +type AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call struct { + *mock.Call +} + +// InsertAuthorizationCode is a helper method to define mock.On call +// - authzCode AuthorizationCode +func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) InsertAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { + return &AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call{Call: _e.mock.On("InsertAuthorizationCode", authzCode)} +} + +func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) Run(run func(authzCode AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 AuthorizationCode + if args[0] != nil { + arg0 = args[0].(AuthorizationCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) Return(err error) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { + _c.Call.Return(err) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) RunAndReturn(run func(authzCode AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { + _c.Call.Return(run) + return _c +} + +// RevokeAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock +func (_mock *AuthorizationCodeStoreInterfaceMock) RevokeAuthorizationCode(authzCode AuthorizationCode) error { + ret := _mock.Called(authzCode) + + if len(ret) == 0 { + panic("no return value specified for RevokeAuthorizationCode") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(AuthorizationCode) error); ok { + r0 = returnFunc(authzCode) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'RevokeAuthorizationCode' +type AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call struct { + *mock.Call +} + +// RevokeAuthorizationCode is a helper method to define mock.On call +// - authzCode AuthorizationCode +func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) RevokeAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { + return &AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call{Call: _e.mock.On("RevokeAuthorizationCode", authzCode)} +} + +func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) Run(run func(authzCode AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 AuthorizationCode + if args[0] != nil { + arg0 = args[0].(AuthorizationCode) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) Return(err error) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { + _c.Call.Return(err) + return _c +} + +func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) RunAndReturn(run func(authzCode AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/oauth/oauth2/authz/AuthorizeServiceInterface_mock_test.go b/backend/internal/oauth/oauth2/authz/AuthorizeServiceInterface_mock_test.go new file mode 100644 index 00000000..96c51fdf --- /dev/null +++ b/backend/internal/oauth/oauth2/authz/AuthorizeServiceInterface_mock_test.go @@ -0,0 +1,104 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authz + +import ( + mock "github.com/stretchr/testify/mock" +) + +// NewAuthorizeServiceInterfaceMock creates a new instance of AuthorizeServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthorizeServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthorizeServiceInterfaceMock { + mock := &AuthorizeServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// AuthorizeServiceInterfaceMock is an autogenerated mock type for the AuthorizeServiceInterface type +type AuthorizeServiceInterfaceMock struct { + mock.Mock +} + +type AuthorizeServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthorizeServiceInterfaceMock) EXPECT() *AuthorizeServiceInterfaceMock_Expecter { + return &AuthorizeServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// GetAuthorizationCodeDetails provides a mock function for the type AuthorizeServiceInterfaceMock +func (_mock *AuthorizeServiceInterfaceMock) GetAuthorizationCodeDetails(clientID string, code string) (*AuthorizationCode, error) { + ret := _mock.Called(clientID, code) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizationCodeDetails") + } + + var r0 *AuthorizationCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string) (*AuthorizationCode, error)); ok { + return returnFunc(clientID, code) + } + if returnFunc, ok := ret.Get(0).(func(string, string) *AuthorizationCode); ok { + r0 = returnFunc(clientID, code) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*AuthorizationCode) + } + } + if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { + r1 = returnFunc(clientID, code) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizationCodeDetails' +type AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call struct { + *mock.Call +} + +// GetAuthorizationCodeDetails is a helper method to define mock.On call +// - clientID string +// - code string +func (_e *AuthorizeServiceInterfaceMock_Expecter) GetAuthorizationCodeDetails(clientID interface{}, code interface{}) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + return &AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call{Call: _e.mock.On("GetAuthorizationCodeDetails", clientID, code)} +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) Run(run func(clientID string, code string)) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) Return(authorizationCode *AuthorizationCode, err error) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Return(authorizationCode, err) + return _c +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) RunAndReturn(run func(clientID string, code string) (*AuthorizationCode, error)) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/oauth/oauth2/authz/constants/constants.go b/backend/internal/oauth/oauth2/authz/constants.go similarity index 91% rename from backend/internal/oauth/oauth2/authz/constants/constants.go rename to backend/internal/oauth/oauth2/authz/constants.go index a1516712..06a0864a 100644 --- a/backend/internal/oauth/oauth2/authz/constants/constants.go +++ b/backend/internal/oauth/oauth2/authz/constants.go @@ -16,8 +16,7 @@ * under the License. */ -// Package constants defines constants related to OAuth2 authorization. -package constants +package authz import "errors" diff --git a/backend/internal/oauth/oauth2/authz/authzhandler.go b/backend/internal/oauth/oauth2/authz/handler.go similarity index 79% rename from backend/internal/oauth/oauth2/authz/authzhandler.go rename to backend/internal/oauth/oauth2/authz/handler.go index b484b3a6..8b8ed6c4 100644 --- a/backend/internal/oauth/oauth2/authz/authzhandler.go +++ b/backend/internal/oauth/oauth2/authz/handler.go @@ -16,7 +16,6 @@ * under the License. */ -// Package authz provides handlers and utilities for managing OAuth2 authorization requests. package authz import ( @@ -28,15 +27,9 @@ import ( "time" "github.com/asgardeo/thunder/internal/application" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/store" oauth2const "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" oauth2model "github.com/asgardeo/thunder/internal/oauth/oauth2/model" oauth2utils "github.com/asgardeo/thunder/internal/oauth/oauth2/utils" - sessionmodel "github.com/asgardeo/thunder/internal/oauth/session/model" - sessionstore "github.com/asgardeo/thunder/internal/oauth/session/store" - sessionutils "github.com/asgardeo/thunder/internal/oauth/session/utils" "github.com/asgardeo/thunder/internal/system/config" serverconst "github.com/asgardeo/thunder/internal/system/constants" "github.com/asgardeo/thunder/internal/system/jwt" @@ -53,28 +46,32 @@ type AuthorizeHandlerInterface interface { HandleAuthorizePostRequest(w http.ResponseWriter, r *http.Request) } -// AuthorizeHandler implements the AuthorizeHandlerInterface for handling OAuth2 authorization requests. -type AuthorizeHandler struct { - AppProvider application.ApplicationProviderInterface - AuthZValidator AuthorizationValidatorInterface - AuthZStore store.AuthorizationCodeStoreInterface - SessionStore sessionstore.SessionDataStoreInterface - JWTService jwt.JWTServiceInterface +// authorizeHandler implements the AuthorizeHandlerInterface for handling OAuth2 authorization requests. +type authorizeHandler struct { + appService application.ApplicationServiceInterface + authZValidator AuthorizationValidatorInterface + authZStore AuthorizationCodeStoreInterface + sessionStore sessionDataStoreInterface + jwtService jwt.JWTServiceInterface } -// NewAuthorizeHandler creates a new instance of AuthorizeHandler. -func NewAuthorizeHandler() AuthorizeHandlerInterface { - return &AuthorizeHandler{ - AppProvider: application.NewApplicationProvider(), - AuthZValidator: NewAuthorizationValidator(), - AuthZStore: store.NewAuthorizationCodeStore(), - SessionStore: sessionstore.GetSessionDataStore(), - JWTService: jwt.GetJWTService(), +// newAuthorizeHandler creates a new instance of authorizeHandler with injected dependencies. +func newAuthorizeHandler( + appService application.ApplicationServiceInterface, + jwtService jwt.JWTServiceInterface, + authZStore AuthorizationCodeStoreInterface, +) AuthorizeHandlerInterface { + return &authorizeHandler{ + appService: appService, + authZValidator: newAuthorizationValidator(), + authZStore: authZStore, + sessionStore: newSessionDataStore(), + jwtService: jwtService, } } // HandleAuthorizeGetRequest handles the GET request for OAuth2 authorization. -func (ah *AuthorizeHandler) HandleAuthorizeGetRequest(w http.ResponseWriter, r *http.Request) { +func (ah *authorizeHandler) HandleAuthorizeGetRequest(w http.ResponseWriter, r *http.Request) { oAuthMessage := ah.getOAuthMessage(r, w) if oAuthMessage == nil { return @@ -83,7 +80,7 @@ func (ah *AuthorizeHandler) HandleAuthorizeGetRequest(w http.ResponseWriter, r * } // HandleAuthorizePostRequest handles the POST request for OAuth2 authorization. -func (ah *AuthorizeHandler) HandleAuthorizePostRequest(w http.ResponseWriter, r *http.Request) { +func (ah *authorizeHandler) HandleAuthorizePostRequest(w http.ResponseWriter, r *http.Request) { oAuthMessage := ah.getOAuthMessage(r, w) if oAuthMessage == nil { return @@ -104,7 +101,7 @@ func (ah *AuthorizeHandler) HandleAuthorizePostRequest(w http.ResponseWriter, r } // handleInitialAuthorizationRequest handles the initial authorization request from the client. -func (ah *AuthorizeHandler) handleInitialAuthorizationRequest(msg *model.OAuthMessage, +func (ah *authorizeHandler) handleInitialAuthorizationRequest(msg *OAuthMessage, w http.ResponseWriter, r *http.Request) { // Extract required parameters. clientID := msg.RequestQueryParams[oauth2const.RequestParamClientID] @@ -123,15 +120,14 @@ func (ah *AuthorizeHandler) handleInitialAuthorizationRequest(msg *model.OAuthMe } // Retrieve the OAuth application based on the client Id. - appService := ah.AppProvider.GetApplicationService() - app, svcErr := appService.GetOAuthApplication(clientID) + app, svcErr := ah.appService.GetOAuthApplication(clientID) if svcErr != nil || app == nil { ah.redirectToErrorPage(w, r, oauth2const.ErrorInvalidClient, "Invalid client_id") return } // Validate the authorization request. - sendErrorToApp, errorCode, errorMessage := ah.AuthZValidator.validateInitialAuthorizationRequest(msg, app) + sendErrorToApp, errorCode, errorMessage := ah.authZValidator.validateInitialAuthorizationRequest(msg, app) if errorCode != "" { if sendErrorToApp && redirectURI != "" { // Redirect to the redirect URI with an error. @@ -157,7 +153,6 @@ func (ah *AuthorizeHandler) handleInitialAuthorizationRequest(msg *model.OAuthMe // Construct session data. oauthParams := oauth2model.OAuthParameters{ - SessionDataKey: sessionutils.GenerateNewSessionDataKey(), State: state, ClientID: clientID, RedirectURI: redirectURI, @@ -173,17 +168,17 @@ func (ah *AuthorizeHandler) handleInitialAuthorizationRequest(msg *model.OAuthMe oauthParams.RedirectURI = app.RedirectURIs[0] } - sessionData := sessionmodel.SessionData{ + sessionData := SessionData{ OAuthParameters: oauthParams, AuthTime: time.Now(), } // Store session data in the session store. - ah.SessionStore.AddSession(oauthParams.SessionDataKey, sessionData) + identifier := ah.sessionStore.AddSession(sessionData) // Add required query parameters. queryParams := make(map[string]string) - queryParams[oauth2const.SessionDataKey] = oauthParams.SessionDataKey + queryParams[oauth2const.SessionDataKey] = identifier queryParams[oauth2const.AppID] = app.AppID // Add insecure warning if the redirect URI is not using TLS. @@ -201,13 +196,13 @@ func (ah *AuthorizeHandler) handleInitialAuthorizationRequest(msg *model.OAuthMe } // handleAuthorizationResponseFromEngine handles the authorization response from the engine. -func (ah *AuthorizeHandler) handleAuthorizationResponseFromEngine(msg *model.OAuthMessage, +func (ah *authorizeHandler) handleAuthorizationResponseFromEngine(msg *OAuthMessage, w http.ResponseWriter) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) // Validate the session data. - sessionData := msg.SessionData - if sessionData == nil { + sessionData, err := ah.loadSessionData(msg.SessionDataKey) + if err != nil { ah.writeAuthZResponseToErrorPage(w, oauth2const.ErrorInvalidRequest, "Invalid authorization request", nil) return } @@ -221,7 +216,7 @@ func (ah *AuthorizeHandler) handleAuthorizationResponseFromEngine(msg *model.OAu } // Verify the assertion. - err := ah.verifyAssertion(assertion, logger) + err = ah.verifyAssertion(assertion, logger) if err != nil { ah.writeAuthZResponseToErrorPage(w, oauth2const.ErrorInvalidRequest, err.Error(), sessionData) return @@ -245,7 +240,7 @@ func (ah *AuthorizeHandler) handleAuthorizationResponseFromEngine(msg *model.OAu // Should validate for the scopes as well. // Generate the authorization code. - authzCode, err := getAuthorizationCode(msg, userID) + authzCode, err := createAuthorizationCode(sessionData, userID) if err != nil { logger.Error("Failed to generate authorization code", log.Error(err)) ah.writeAuthZResponseToErrorPage(w, oauth2const.ErrorServerError, "Failed to generate authorization code", @@ -254,7 +249,7 @@ func (ah *AuthorizeHandler) handleAuthorizationResponseFromEngine(msg *model.OAu } // Persist the authorization code. - persistErr := ah.AuthZStore.InsertAuthorizationCode(authzCode) + persistErr := ah.authZStore.InsertAuthorizationCode(authzCode) if persistErr != nil { logger.Error("Failed to persist authorization code", log.Error(persistErr)) ah.writeAuthZResponseToErrorPage(w, oauth2const.ErrorServerError, "Failed to persist authorization code", @@ -271,8 +266,19 @@ func (ah *AuthorizeHandler) handleAuthorizationResponseFromEngine(msg *model.OAu ah.writeAuthZResponse(w, redirectURI) } +func (ah *authorizeHandler) loadSessionData(sessionDataKey string) (*SessionData, error) { + ok, sessionData := ah.sessionStore.GetSession(sessionDataKey) + if !ok { + return nil, fmt.Errorf("session data not found for session data key: %s", sessionDataKey) + } + + // Remove the session data after retrieval. + ah.sessionStore.ClearSession(sessionDataKey) + return &sessionData, nil +} + // getOAuthMessage extracts the OAuth message from the request and response writer. -func (ah *AuthorizeHandler) getOAuthMessage(r *http.Request, w http.ResponseWriter) *model.OAuthMessage { +func (ah *authorizeHandler) getOAuthMessage(r *http.Request, w http.ResponseWriter) *OAuthMessage { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) if r == nil || w == nil { @@ -280,7 +286,7 @@ func (ah *AuthorizeHandler) getOAuthMessage(r *http.Request, w http.ResponseWrit return nil } - var msg *model.OAuthMessage + var msg *OAuthMessage var err error switch r.Method { @@ -301,7 +307,7 @@ func (ah *AuthorizeHandler) getOAuthMessage(r *http.Request, w http.ResponseWrit } // getOAuthMessageForGetRequest extracts the OAuth message from a authorization GET request. -func (ah *AuthorizeHandler) getOAuthMessageForGetRequest(r *http.Request) (*model.OAuthMessage, error) { +func (ah *authorizeHandler) getOAuthMessageForGetRequest(r *http.Request) (*OAuthMessage, error) { if err := r.ParseForm(); err != nil { return nil, errors.New("failed to parse form data: " + err.Error()) } @@ -313,16 +319,15 @@ func (ah *AuthorizeHandler) getOAuthMessageForGetRequest(r *http.Request) (*mode } } - return &model.OAuthMessage{ + return &OAuthMessage{ RequestType: oauth2const.TypeInitialAuthorizationRequest, - SessionData: nil, RequestQueryParams: queryParams, }, nil } // getOAuthMessageForPostRequest extracts the OAuth message from a authorization POST request. -func (ah *AuthorizeHandler) getOAuthMessageForPostRequest(r *http.Request) (*model.OAuthMessage, error) { - authZReq, err := systemutils.DecodeJSONBody[model.AuthZPostRequest](r) +func (ah *authorizeHandler) getOAuthMessageForPostRequest(r *http.Request) (*OAuthMessage, error) { + authZReq, err := systemutils.DecodeJSONBody[AuthZPostRequest](r) if err != nil { return nil, fmt.Errorf("failed to decode JSON body: %w", err) } @@ -335,22 +340,13 @@ func (ah *AuthorizeHandler) getOAuthMessageForPostRequest(r *http.Request) (*mod // TODO: Require to handle other types such as user consent, etc. requestType := oauth2const.TypeAuthorizationResponseFromEngine - sessionDataKey := authZReq.SessionDataKey - ok, sessionData := ah.SessionStore.GetSession(sessionDataKey) - if !ok { - return nil, fmt.Errorf("session data not found for session data key: %s", sessionDataKey) - } - - // Remove the session data after retrieval. - ah.SessionStore.ClearSession(sessionDataKey) - bodyParams := map[string]string{ oauth2const.Assertion: authZReq.Assertion, } - return &model.OAuthMessage{ + return &OAuthMessage{ RequestType: requestType, - SessionData: &sessionData, + SessionDataKey: authZReq.SessionDataKey, RequestBodyParams: bodyParams, }, nil } @@ -368,7 +364,7 @@ func getLoginPageRedirectURI(queryParams map[string]string) (string, error) { } // redirectToLoginPage constructs the login page URL and redirects the user to it. -func (ah *AuthorizeHandler) redirectToLoginPage(w http.ResponseWriter, r *http.Request, +func (ah *authorizeHandler) redirectToLoginPage(w http.ResponseWriter, r *http.Request, queryParams map[string]string) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) @@ -405,7 +401,7 @@ func getErrorPageRedirectURL(code, msg string) (string, error) { } // redirectToErrorPage constructs the error page URL and redirects the user to it. -func (ah *AuthorizeHandler) redirectToErrorPage(w http.ResponseWriter, r *http.Request, code, msg string) { +func (ah *authorizeHandler) redirectToErrorPage(w http.ResponseWriter, r *http.Request, code, msg string) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) if w == nil || r == nil { @@ -425,10 +421,10 @@ func (ah *AuthorizeHandler) redirectToErrorPage(w http.ResponseWriter, r *http.R } // writeAuthZResponse writes the authorization response to the HTTP response writer. -func (ah *AuthorizeHandler) writeAuthZResponse(w http.ResponseWriter, redirectURI string) { +func (ah *authorizeHandler) writeAuthZResponse(w http.ResponseWriter, redirectURI string) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) - authZResp := model.AuthZPostResponse{ + authZResp := AuthZPostResponse{ RedirectURI: redirectURI, } @@ -444,8 +440,8 @@ func (ah *AuthorizeHandler) writeAuthZResponse(w http.ResponseWriter, redirectUR } // writeAuthZResponseToErrorPage writes the authorization response to the error page. -func (ah *AuthorizeHandler) writeAuthZResponseToErrorPage(w http.ResponseWriter, code, msg string, - sessionData *sessionmodel.SessionData) { +func (ah *authorizeHandler) writeAuthZResponseToErrorPage(w http.ResponseWriter, code, msg string, + sessionData *SessionData) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) redirectURI, err := getErrorPageRedirectURL(code, msg) @@ -462,41 +458,31 @@ func (ah *AuthorizeHandler) writeAuthZResponseToErrorPage(w http.ResponseWriter, ah.writeAuthZResponse(w, redirectURI) } -// getAuthorizationCode generates an authorization code based on the provided OAuth message. -func getAuthorizationCode(oAuthMessage *model.OAuthMessage, authUserID string) ( - model.AuthorizationCode, error) { - sessionData := oAuthMessage.SessionData +// createAuthorizationCode generates an authorization code based on the provided Session data and authenticated user. +func createAuthorizationCode(sessionData *SessionData, authUserID string) ( + AuthorizationCode, error) { clientID := sessionData.OAuthParameters.ClientID - if clientID == "" { - clientID = oAuthMessage.RequestQueryParams["client_id"] - } redirectURI := sessionData.OAuthParameters.RedirectURI - if redirectURI == "" { - redirectURI = oAuthMessage.RequestQueryParams["redirect_uri"] - } if clientID == "" || redirectURI == "" { - return model.AuthorizationCode{}, errors.New("client_id or redirect_uri is missing") + return AuthorizationCode{}, errors.New("client_id or redirect_uri is missing") } if authUserID == "" { - return model.AuthorizationCode{}, errors.New("authenticated user not found") + return AuthorizationCode{}, errors.New("authenticated user not found") } authTime := sessionData.AuthTime if authTime.IsZero() { - return model.AuthorizationCode{}, errors.New("authentication time is not set") + return AuthorizationCode{}, errors.New("authentication time is not set") } scope := sessionData.OAuthParameters.Scopes - if scope == "" { - scope = oAuthMessage.RequestQueryParams["scope"] - } // TODO: Add expiry time logic. expiryTime := authTime.Add(10 * time.Minute) - return model.AuthorizationCode{ + return AuthorizationCode{ CodeID: utils.GenerateUUID(), Code: utils.GenerateUUID(), ClientID: clientID, @@ -505,15 +491,15 @@ func getAuthorizationCode(oAuthMessage *model.OAuthMessage, authUserID string) ( TimeCreated: authTime, ExpiryTime: expiryTime, Scopes: scope, - State: constants.AuthCodeStateActive, + State: AuthCodeStateActive, CodeChallenge: sessionData.OAuthParameters.CodeChallenge, CodeChallengeMethod: sessionData.OAuthParameters.CodeChallengeMethod, }, nil } // verifyAssertion verifies the JWT assertion. -func (ah *AuthorizeHandler) verifyAssertion(assertion string, logger *log.Logger) error { - if err := ah.JWTService.VerifyJWT(assertion, "", ""); err != nil { +func (ah *authorizeHandler) verifyAssertion(assertion string, logger *log.Logger) error { + if err := ah.jwtService.VerifyJWT(assertion, "", ""); err != nil { logger.Debug("Invalid assertion signature", log.Error(err)) return errors.New("invalid assertion signature") } diff --git a/backend/internal/oauth/oauth2/authz/authzhandler_test.go b/backend/internal/oauth/oauth2/authz/handler_test.go similarity index 69% rename from backend/internal/oauth/oauth2/authz/authzhandler_test.go rename to backend/internal/oauth/oauth2/authz/handler_test.go index 091e1596..95393e4d 100644 --- a/backend/internal/oauth/oauth2/authz/authzhandler_test.go +++ b/backend/internal/oauth/oauth2/authz/handler_test.go @@ -29,17 +29,20 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" + "github.com/asgardeo/thunder/tests/mocks/applicationmock" + "github.com/asgardeo/thunder/tests/mocks/jwtmock" + oauth2const "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" oauth2model "github.com/asgardeo/thunder/internal/oauth/oauth2/model" - sessionmodel "github.com/asgardeo/thunder/internal/oauth/session/model" "github.com/asgardeo/thunder/internal/system/config" ) type AuthorizeHandlerTestSuite struct { suite.Suite - handler *AuthorizeHandler + handler *authorizeHandler + mockAppService *applicationmock.ApplicationServiceInterfaceMock + mockJWTService *jwtmock.JWTServiceInterfaceMock + mockAuthzCodeStore *AuthorizationCodeStoreInterfaceMock } func TestAuthorizeHandlerTestSuite(t *testing.T) { @@ -59,11 +62,18 @@ func (suite *AuthorizeHandlerTestSuite) SetupTest() { } _ = config.InitializeThunderRuntime("test", testConfig) - suite.handler = NewAuthorizeHandler().(*AuthorizeHandler) + // Create mocked dependencies for testing + suite.mockAppService = applicationmock.NewApplicationServiceInterfaceMock(suite.T()) + suite.mockJWTService = jwtmock.NewJWTServiceInterfaceMock(suite.T()) + suite.mockAuthzCodeStore = NewAuthorizationCodeStoreInterfaceMock(suite.T()) + + suite.handler = newAuthorizeHandler( + suite.mockAppService, suite.mockJWTService, suite.mockAuthzCodeStore).(*authorizeHandler) } -func (suite *AuthorizeHandlerTestSuite) TestNewAuthorizeHandler() { - handler := NewAuthorizeHandler() +func (suite *AuthorizeHandlerTestSuite) TestnewAuthorizeHandler() { + mockStore := NewAuthorizationCodeStoreInterfaceMock(suite.T()) + handler := newAuthorizeHandler(suite.mockAppService, suite.mockJWTService, mockStore) assert.NotNil(suite.T(), handler) assert.Implements(suite.T(), (*AuthorizeHandlerInterface)(nil), handler) } @@ -75,10 +85,12 @@ func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForGetRequest_Success assert.NoError(suite.T(), err) assert.NotNil(suite.T(), msg) - assert.Equal(suite.T(), oauth2const.TypeInitialAuthorizationRequest, msg.RequestType) - assert.Equal(suite.T(), "test-client", msg.RequestQueryParams["client_id"]) - assert.Equal(suite.T(), "https://example.com", msg.RequestQueryParams["redirect_uri"]) - assert.Nil(suite.T(), msg.SessionData) + if msg != nil { + assert.Equal(suite.T(), oauth2const.TypeInitialAuthorizationRequest, msg.RequestType) + assert.Equal(suite.T(), "test-client", msg.RequestQueryParams["client_id"]) + assert.Equal(suite.T(), "https://example.com", msg.RequestQueryParams["redirect_uri"]) + assert.Empty(suite.T(), msg.SessionDataKey) + } } func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForGetRequest_ParseFormError() { @@ -93,7 +105,7 @@ func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForGetRequest_ParseFo } func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForPostRequest_MissingSessionDataKey() { - postData := model.AuthZPostRequest{ + postData := AuthZPostRequest{ SessionDataKey: "", // Missing session data key Assertion: "test-assertion", } @@ -110,7 +122,7 @@ func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForPostRequest_Missin } func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessageForPostRequest_MissingAssertion() { - postData := model.AuthZPostRequest{ + postData := AuthZPostRequest{ SessionDataKey: "test-session-key", Assertion: "", // Missing assertion } @@ -154,7 +166,7 @@ func (suite *AuthorizeHandlerTestSuite) TestGetOAuthMessage_NilResponseWriter() func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_Success() { // Create a valid OAuth message with session data - sessionData := &sessionmodel.SessionData{ + sessionData := &SessionData{ OAuthParameters: oauth2model.OAuthParameters{ ClientID: "test-client", RedirectURI: "https://client.example.com/callback", @@ -163,11 +175,7 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_Success() { AuthTime: time.Now(), } - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - } - - result, err := getAuthorizationCode(oAuthMessage, "test-user") + result, err := createAuthorizationCode(sessionData, "test-user") assert.NoError(suite.T(), err) assert.NotEmpty(suite.T(), result.CodeID) @@ -176,13 +184,13 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_Success() { assert.Equal(suite.T(), "https://client.example.com/callback", result.RedirectURI) assert.Equal(suite.T(), "test-user", result.AuthorizedUserID) assert.Equal(suite.T(), "read write", result.Scopes) - assert.Equal(suite.T(), constants.AuthCodeStateActive, result.State) + assert.Equal(suite.T(), AuthCodeStateActive, result.State) assert.NotZero(suite.T(), result.TimeCreated) assert.True(suite.T(), result.ExpiryTime.After(result.TimeCreated)) } func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_MissingClientID() { - sessionData := &sessionmodel.SessionData{ + sessionData := &SessionData{ OAuthParameters: oauth2model.OAuthParameters{ ClientID: "", // Missing client ID RedirectURI: "https://client.example.com/callback", @@ -190,22 +198,15 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_MissingClientID AuthTime: time.Now(), } - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - RequestQueryParams: map[string]string{ - "redirect_uri": "https://client.example.com/callback", - }, - } - - result, err := getAuthorizationCode(oAuthMessage, "test-user") + result, err := createAuthorizationCode(sessionData, "test-user") assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "client_id or redirect_uri is missing") - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), AuthorizationCode{}, result) } func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_MissingRedirectURI() { - sessionData := &sessionmodel.SessionData{ + sessionData := &SessionData{ OAuthParameters: oauth2model.OAuthParameters{ ClientID: "test-client", RedirectURI: "", // Missing redirect URI @@ -213,22 +214,15 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_MissingRedirect AuthTime: time.Now(), } - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - RequestQueryParams: map[string]string{ - "client_id": "test-client", - }, - } - - result, err := getAuthorizationCode(oAuthMessage, "test-user") + result, err := createAuthorizationCode(sessionData, "test-user") assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "client_id or redirect_uri is missing") - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), AuthorizationCode{}, result) } func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_EmptyUserID() { - sessionData := &sessionmodel.SessionData{ + sessionData := &SessionData{ OAuthParameters: oauth2model.OAuthParameters{ ClientID: "test-client", RedirectURI: "https://client.example.com/callback", @@ -236,19 +230,15 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_EmptyUserID() { AuthTime: time.Now(), } - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - } - - result, err := getAuthorizationCode(oAuthMessage, "") // Empty user ID + result, err := createAuthorizationCode(sessionData, "") // Empty user ID assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "authenticated user not found") - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), AuthorizationCode{}, result) } func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_ZeroAuthTime() { - sessionData := &sessionmodel.SessionData{ + sessionData := &SessionData{ OAuthParameters: oauth2model.OAuthParameters{ ClientID: "test-client", RedirectURI: "https://client.example.com/callback", @@ -256,43 +246,11 @@ func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_ZeroAuthTime() AuthTime: time.Time{}, // Zero time } - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - } - - result, err := getAuthorizationCode(oAuthMessage, "test-user") + result, err := createAuthorizationCode(sessionData, "test-user") assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "authentication time is not set") - assert.Equal(suite.T(), model.AuthorizationCode{}, result) -} - -func (suite *AuthorizeHandlerTestSuite) TestGetAuthorizationCode_FallbackToQueryParams() { - // Test fallback to query params when session data is missing values - sessionData := &sessionmodel.SessionData{ - OAuthParameters: oauth2model.OAuthParameters{ - ClientID: "", // Missing in session, should fallback to query params - RedirectURI: "", // Missing in session, should fallback to query params - Scopes: "", // Missing in session, should fallback to query params - }, - AuthTime: time.Now(), - } - - oAuthMessage := &model.OAuthMessage{ - SessionData: sessionData, - RequestQueryParams: map[string]string{ - "client_id": "fallback-client", - "redirect_uri": "https://fallback.example.com/callback", - "scope": "fallback-scope", - }, - } - - result, err := getAuthorizationCode(oAuthMessage, "test-user") - - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "fallback-client", result.ClientID) - assert.Equal(suite.T(), "https://fallback.example.com/callback", result.RedirectURI) - assert.Equal(suite.T(), "fallback-scope", result.Scopes) + assert.Equal(suite.T(), AuthorizationCode{}, result) } func (suite *AuthorizeHandlerTestSuite) TestGetLoginPageRedirectURI_Success() { diff --git a/backend/internal/system/services/authorizationservice.go b/backend/internal/oauth/oauth2/authz/init.go similarity index 55% rename from backend/internal/system/services/authorizationservice.go rename to backend/internal/oauth/oauth2/authz/init.go index 2194dc1e..e57ccb67 100644 --- a/backend/internal/system/services/authorizationservice.go +++ b/backend/internal/oauth/oauth2/authz/init.go @@ -16,45 +16,43 @@ * under the License. */ -package services +package authz import ( "net/http" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" + "github.com/asgardeo/thunder/internal/application" + "github.com/asgardeo/thunder/internal/system/jwt" "github.com/asgardeo/thunder/internal/system/middleware" ) -// AuthorizationService defines the service for handling OAuth2 authorization requests. -type AuthorizationService struct { - authHandler authz.AuthorizeHandlerInterface +// Initialize initializes the authorization handler and registers its routes. +func Initialize( + mux *http.ServeMux, + applicationService application.ApplicationServiceInterface, + jwtService jwt.JWTServiceInterface, +) AuthorizeServiceInterface { + authzCodeStore := newAuthorizationCodeStore() + authzService := newAuthorizeService(authzCodeStore) + authzHandler := newAuthorizeHandler(applicationService, jwtService, authzCodeStore) + registerRoutes(mux, authzHandler) + return authzService } -// NewAuthorizationService creates a new instance of AuthorizationService. -func NewAuthorizationService(mux *http.ServeMux) ServiceInterface { - instance := &AuthorizationService{ - authHandler: authz.NewAuthorizeHandler(), - } - instance.RegisterRoutes(mux) - - return instance -} - -// RegisterRoutes registers the routes for the AuthorizationService. -func (s *AuthorizationService) RegisterRoutes(mux *http.ServeMux) { - opts1 := middleware.CORSOptions{ +// registerRoutes registers the routes for OAuth2 authorization operations. +func registerRoutes(mux *http.ServeMux, authzHandler AuthorizeHandlerInterface) { + opts := middleware.CORSOptions{ AllowedMethods: "GET, POST", AllowedHeaders: "Content-Type, Authorization", AllowCredentials: true, } mux.HandleFunc(middleware.WithCORS("GET /oauth2/authorize", - s.authHandler.HandleAuthorizeGetRequest, opts1)) + authzHandler.HandleAuthorizeGetRequest, opts)) mux.HandleFunc(middleware.WithCORS("POST /oauth2/authorize", - s.authHandler.HandleAuthorizePostRequest, opts1)) - + authzHandler.HandleAuthorizePostRequest, opts)) mux.HandleFunc(middleware.WithCORS("OPTIONS /oauth2/authorize", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) - }, opts1)) + }, opts)) } diff --git a/backend/internal/oauth/oauth2/authz/model/model.go b/backend/internal/oauth/oauth2/authz/model.go similarity index 88% rename from backend/internal/oauth/oauth2/authz/model/model.go rename to backend/internal/oauth/oauth2/authz/model.go index b5b01a89..f9b8701a 100644 --- a/backend/internal/oauth/oauth2/authz/model/model.go +++ b/backend/internal/oauth/oauth2/authz/model.go @@ -16,19 +16,16 @@ * under the License. */ -// Package model defines the data structures for OAuth2 authorization. -package model +package authz import ( "time" - - sessionmodel "github.com/asgardeo/thunder/internal/oauth/session/model" ) // OAuthMessage represents the OAuth message. type OAuthMessage struct { RequestType string - SessionData *sessionmodel.SessionData + SessionDataKey string RequestQueryParams map[string]string RequestBodyParams map[string]string } diff --git a/backend/internal/oauth/oauth2/authz/service.go b/backend/internal/oauth/oauth2/authz/service.go new file mode 100644 index 00000000..3c763ebb --- /dev/null +++ b/backend/internal/oauth/oauth2/authz/service.go @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +// Package authz implements the OAuth2 authorization functionality. +package authz + +import ( + "errors" +) + +// AuthorizeServiceInterface defines the interface for authorization services. +type AuthorizeServiceInterface interface { + GetAuthorizationCodeDetails(clientID string, code string) (*AuthorizationCode, error) +} + +// authorizeService implements the AuthorizeService for managing OAuth2 authorization flows. +type authorizeService struct { + authzStore AuthorizationCodeStoreInterface +} + +// newAuthorizeService creates a new instance of authorizeService with injected dependencies. +func newAuthorizeService(authzStore AuthorizationCodeStoreInterface) AuthorizeServiceInterface { + return &authorizeService{ + authzStore: authzStore, + } +} + +func (as *authorizeService) GetAuthorizationCodeDetails( + clientID string, code string) (*AuthorizationCode, error) { + authCode, err := as.authzStore.GetAuthorizationCode(clientID, code) + if err != nil || authCode.Code == "" { + return nil, errors.New("invalid authorization code") + } + + // Invalidate the authorization code after use. + err = as.authzStore.DeactivateAuthorizationCode(authCode) + if err != nil { + return nil, errors.New("failed to invalidate authorization code") + } + return &authCode, nil +} diff --git a/backend/internal/oauth/oauth2/authz/service_test.go b/backend/internal/oauth/oauth2/authz/service_test.go new file mode 100644 index 00000000..3b34d8a9 --- /dev/null +++ b/backend/internal/oauth/oauth2/authz/service_test.go @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package authz + +import ( + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type AuthorizeServiceTestSuite struct { + suite.Suite + service AuthorizeServiceInterface + mockAuthzStore *AuthorizationCodeStoreInterfaceMock + testAuthzCode AuthorizationCode + testClientID string + testCode string +} + +func TestAuthorizeServiceTestSuite(t *testing.T) { + suite.Run(t, new(AuthorizeServiceTestSuite)) +} + +func (suite *AuthorizeServiceTestSuite) SetupTest() { + suite.mockAuthzStore = NewAuthorizationCodeStoreInterfaceMock(suite.T()) + suite.service = newAuthorizeService(suite.mockAuthzStore) + + suite.testClientID = "test-client-id" + suite.testCode = "test-auth-code" + + suite.testAuthzCode = AuthorizationCode{ + CodeID: "test-code-id", + Code: suite.testCode, + ClientID: suite.testClientID, + RedirectURI: "https://client.example.com/callback", + AuthorizedUserID: "test-user-id", + TimeCreated: time.Now().Add(-5 * time.Minute), + ExpiryTime: time.Now().Add(5 * time.Minute), + Scopes: "read write", + State: AuthCodeStateActive, + } +} + +func (suite *AuthorizeServiceTestSuite) TestNewAuthorizeService() { + service := newAuthorizeService(suite.mockAuthzStore) + assert.NotNil(suite.T(), service) + assert.Implements(suite.T(), (*AuthorizeServiceInterface)(nil), service) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_Success() { + // Mock store to return valid authorization code + suite.mockAuthzStore.On("GetAuthorizationCode", suite.testClientID, suite.testCode). + Return(suite.testAuthzCode, nil) + suite.mockAuthzStore.On("DeactivateAuthorizationCode", suite.testAuthzCode). + Return(nil) + + result, err := suite.service.GetAuthorizationCodeDetails(suite.testClientID, suite.testCode) + + assert.Nil(suite.T(), err) + assert.NotNil(suite.T(), result) + assert.Equal(suite.T(), suite.testAuthzCode.Code, result.Code) + assert.Equal(suite.T(), suite.testAuthzCode.ClientID, result.ClientID) + assert.Equal(suite.T(), suite.testAuthzCode.AuthorizedUserID, result.AuthorizedUserID) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_StoreError() { + // Mock store to return error + suite.mockAuthzStore.On("GetAuthorizationCode", suite.testClientID, suite.testCode). + Return(AuthorizationCode{}, errors.New("database error")) + + result, err := suite.service.GetAuthorizationCodeDetails(suite.testClientID, suite.testCode) + + assert.Nil(suite.T(), result) + assert.NotNil(suite.T(), err) + assert.Equal(suite.T(), "invalid authorization code", err.Error()) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_EmptyCode() { + // Mock store to return authorization code with empty code string + emptyAuthzCode := suite.testAuthzCode + emptyAuthzCode.Code = "" + + suite.mockAuthzStore.On("GetAuthorizationCode", suite.testClientID, suite.testCode). + Return(emptyAuthzCode, nil) + + result, err := suite.service.GetAuthorizationCodeDetails(suite.testClientID, suite.testCode) + + assert.Nil(suite.T(), result) + assert.NotNil(suite.T(), err) + assert.Equal(suite.T(), "invalid authorization code", err.Error()) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_DeactivationError() { + // Mock store to return valid code but fail on deactivation + suite.mockAuthzStore.On("GetAuthorizationCode", suite.testClientID, suite.testCode). + Return(suite.testAuthzCode, nil) + suite.mockAuthzStore.On("DeactivateAuthorizationCode", suite.testAuthzCode). + Return(errors.New("deactivation failed")) + + result, err := suite.service.GetAuthorizationCodeDetails(suite.testClientID, suite.testCode) + + assert.Nil(suite.T(), result) + assert.NotNil(suite.T(), err) + assert.Equal(suite.T(), "failed to invalidate authorization code", err.Error()) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_EmptyClientID() { + // Mock store to be called with empty client ID + suite.mockAuthzStore.On("GetAuthorizationCode", "", suite.testCode). + Return(AuthorizationCode{}, errors.New("invalid client")) + + result, err := suite.service.GetAuthorizationCodeDetails("", suite.testCode) + + assert.Nil(suite.T(), result) + assert.NotNil(suite.T(), err) + assert.Equal(suite.T(), "invalid authorization code", err.Error()) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} + +func (suite *AuthorizeServiceTestSuite) TestGetAuthorizationCodeDetails_EmptyCodeString() { + // Mock store to be called with empty code string + suite.mockAuthzStore.On("GetAuthorizationCode", suite.testClientID, ""). + Return(AuthorizationCode{}, errors.New("invalid code")) + + result, err := suite.service.GetAuthorizationCodeDetails(suite.testClientID, "") + + assert.Nil(suite.T(), result) + assert.NotNil(suite.T(), err) + assert.Equal(suite.T(), "invalid authorization code", err.Error()) + + suite.mockAuthzStore.AssertExpectations(suite.T()) +} diff --git a/backend/internal/oauth/oauth2/authz/sessionDataStoreInterface_mock_test.go b/backend/internal/oauth/oauth2/authz/sessionDataStoreInterface_mock_test.go new file mode 100644 index 00000000..6d6b32bd --- /dev/null +++ b/backend/internal/oauth/oauth2/authz/sessionDataStoreInterface_mock_test.go @@ -0,0 +1,220 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authz + +import ( + mock "github.com/stretchr/testify/mock" +) + +// newSessionDataStoreInterfaceMock creates a new instance of sessionDataStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newSessionDataStoreInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *sessionDataStoreInterfaceMock { + mock := &sessionDataStoreInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// sessionDataStoreInterfaceMock is an autogenerated mock type for the sessionDataStoreInterface type +type sessionDataStoreInterfaceMock struct { + mock.Mock +} + +type sessionDataStoreInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *sessionDataStoreInterfaceMock) EXPECT() *sessionDataStoreInterfaceMock_Expecter { + return &sessionDataStoreInterfaceMock_Expecter{mock: &_m.Mock} +} + +// AddSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) AddSession(value SessionData) string { + ret := _mock.Called(value) + + if len(ret) == 0 { + panic("no return value specified for AddSession") + } + + var r0 string + if returnFunc, ok := ret.Get(0).(func(SessionData) string); ok { + r0 = returnFunc(value) + } else { + r0 = ret.Get(0).(string) + } + return r0 +} + +// sessionDataStoreInterfaceMock_AddSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddSession' +type sessionDataStoreInterfaceMock_AddSession_Call struct { + *mock.Call +} + +// AddSession is a helper method to define mock.On call +// - value SessionData +func (_e *sessionDataStoreInterfaceMock_Expecter) AddSession(value interface{}) *sessionDataStoreInterfaceMock_AddSession_Call { + return &sessionDataStoreInterfaceMock_AddSession_Call{Call: _e.mock.On("AddSession", value)} +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) Run(run func(value SessionData)) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 SessionData + if args[0] != nil { + arg0 = args[0].(SessionData) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) Return(s string) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Return(s) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) RunAndReturn(run func(value SessionData) string) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Return(run) + return _c +} + +// ClearSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) ClearSession(key string) { + _mock.Called(key) + return +} + +// sessionDataStoreInterfaceMock_ClearSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSession' +type sessionDataStoreInterfaceMock_ClearSession_Call struct { + *mock.Call +} + +// ClearSession is a helper method to define mock.On call +// - key string +func (_e *sessionDataStoreInterfaceMock_Expecter) ClearSession(key interface{}) *sessionDataStoreInterfaceMock_ClearSession_Call { + return &sessionDataStoreInterfaceMock_ClearSession_Call{Call: _e.mock.On("ClearSession", key)} +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) Run(run func(key string)) *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) Return() *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Call.Return() + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) RunAndReturn(run func(key string)) *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Run(run) + return _c +} + +// ClearSessionStore provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) ClearSessionStore() { + _mock.Called() + return +} + +// sessionDataStoreInterfaceMock_ClearSessionStore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSessionStore' +type sessionDataStoreInterfaceMock_ClearSessionStore_Call struct { + *mock.Call +} + +// ClearSessionStore is a helper method to define mock.On call +func (_e *sessionDataStoreInterfaceMock_Expecter) ClearSessionStore() *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + return &sessionDataStoreInterfaceMock_ClearSessionStore_Call{Call: _e.mock.On("ClearSessionStore")} +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) Run(run func()) *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) Return() *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Call.Return() + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) RunAndReturn(run func()) *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Run(run) + return _c +} + +// GetSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) GetSession(key string) (bool, SessionData) { + ret := _mock.Called(key) + + if len(ret) == 0 { + panic("no return value specified for GetSession") + } + + var r0 bool + var r1 SessionData + if returnFunc, ok := ret.Get(0).(func(string) (bool, SessionData)); ok { + return returnFunc(key) + } + if returnFunc, ok := ret.Get(0).(func(string) bool); ok { + r0 = returnFunc(key) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(string) SessionData); ok { + r1 = returnFunc(key) + } else { + r1 = ret.Get(1).(SessionData) + } + return r0, r1 +} + +// sessionDataStoreInterfaceMock_GetSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSession' +type sessionDataStoreInterfaceMock_GetSession_Call struct { + *mock.Call +} + +// GetSession is a helper method to define mock.On call +// - key string +func (_e *sessionDataStoreInterfaceMock_Expecter) GetSession(key interface{}) *sessionDataStoreInterfaceMock_GetSession_Call { + return &sessionDataStoreInterfaceMock_GetSession_Call{Call: _e.mock.On("GetSession", key)} +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) Run(run func(key string)) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) Return(b bool, sessionData SessionData) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Return(b, sessionData) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) RunAndReturn(run func(key string) (bool, SessionData)) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/oauth/session/store/sessiondatastore.go b/backend/internal/oauth/oauth2/authz/session_store.go similarity index 60% rename from backend/internal/oauth/session/store/sessiondatastore.go rename to backend/internal/oauth/oauth2/authz/session_store.go index 90828d08..926ca96d 100644 --- a/backend/internal/oauth/session/store/sessiondatastore.go +++ b/backend/internal/oauth/oauth2/authz/session_store.go @@ -16,73 +16,68 @@ * under the License. */ -// Package store provides functionality for managing auth session data storage. -package store +package authz import ( "sync" "time" - "github.com/asgardeo/thunder/internal/oauth/session/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/model" + "github.com/asgardeo/thunder/internal/system/utils" ) -// SessionDataStoreInterface defines the interface for session data storage. -type SessionDataStoreInterface interface { - AddSession(key string, value model.SessionData) - GetSession(key string) (bool, model.SessionData) +// SessionData holds OAuth session information including parameters and authentication time. +type SessionData struct { + OAuthParameters model.OAuthParameters + AuthTime time.Time +} + +// sessionDataStoreInterface defines the interface for session data storage. +type sessionDataStoreInterface interface { + AddSession(value SessionData) string + GetSession(key string) (bool, SessionData) ClearSession(key string) ClearSessionStore() } // sessionStoreEntry represents an entry in the session data store. type sessionStoreEntry struct { - sessionData model.SessionData + sessionData SessionData expiryTime time.Time } -// SessionDataStore provides the session data store functionality. -type SessionDataStore struct { +// sessionDataStore provides the session data store functionality. +type sessionDataStore struct { sessionStore map[string]sessionStoreEntry validityPeriod time.Duration mu sync.RWMutex } -var ( - instance *SessionDataStore - once sync.Once -) - -// GetSessionDataStore returns a singleton instance of SessionDataStore. -func GetSessionDataStore() SessionDataStoreInterface { - once.Do(func() { - instance = &SessionDataStore{ - sessionStore: make(map[string]sessionStoreEntry), - validityPeriod: 10 * time.Minute, // Set a default validity period. - } - }) - - return instance +// newSessionDataStore creates a new instance of sessionDataStore with injected dependencies. +func newSessionDataStore() sessionDataStoreInterface { + return &sessionDataStore{ + sessionStore: make(map[string]sessionStoreEntry), + validityPeriod: 10 * time.Minute, // Set a default validity period. + } } // AddSession adds a session data entry to the session store. -func (sds *SessionDataStore) AddSession(key string, value model.SessionData) { - if key == "" { - return - } - +func (sds *sessionDataStore) AddSession(value SessionData) string { sds.mu.Lock() defer sds.mu.Unlock() + key := utils.GenerateUUID() sds.sessionStore[key] = sessionStoreEntry{ sessionData: value, expiryTime: time.Now().Add(sds.validityPeriod), } + return key } // GetSession retrieves a session data entry from the session store. -func (sdc *SessionDataStore) GetSession(key string) (bool, model.SessionData) { +func (sdc *sessionDataStore) GetSession(key string) (bool, SessionData) { if key == "" { - return false, model.SessionData{} + return false, SessionData{} } sdc.mu.RLock() @@ -100,11 +95,11 @@ func (sdc *SessionDataStore) GetSession(key string) (bool, model.SessionData) { } } - return false, model.SessionData{} + return false, SessionData{} } // ClearSession removes a specific session data entry from the session store. -func (sdc *SessionDataStore) ClearSession(key string) { +func (sdc *sessionDataStore) ClearSession(key string) { if key == "" { return } @@ -115,7 +110,7 @@ func (sdc *SessionDataStore) ClearSession(key string) { } // ClearSessionStore removes all session data entries from the session store. -func (sdc *SessionDataStore) ClearSessionStore() { +func (sdc *sessionDataStore) ClearSessionStore() { sdc.mu.Lock() defer sdc.mu.Unlock() diff --git a/backend/internal/oauth/session/store/sessiondatastore_test.go b/backend/internal/oauth/oauth2/authz/session_store_test.go similarity index 55% rename from backend/internal/oauth/session/store/sessiondatastore_test.go rename to backend/internal/oauth/oauth2/authz/session_store_test.go index 52e50eed..7437c1b1 100644 --- a/backend/internal/oauth/session/store/sessiondatastore_test.go +++ b/backend/internal/oauth/oauth2/authz/session_store_test.go @@ -16,7 +16,7 @@ * under the License. */ -package store +package authz import ( "sync" @@ -26,18 +26,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - authncm "github.com/asgardeo/thunder/internal/authn/common" "github.com/asgardeo/thunder/internal/oauth/oauth2/model" - sessionmodel "github.com/asgardeo/thunder/internal/oauth/session/model" -) - -const ( - testSessionKey = "test-session-key" ) type SessionDataStoreTestSuite struct { suite.Suite - store SessionDataStoreInterface + store sessionDataStoreInterface } func TestSessionDataStoreSuite(t *testing.T) { @@ -45,10 +39,7 @@ func TestSessionDataStoreSuite(t *testing.T) { } func (suite *SessionDataStoreTestSuite) SetupTest() { - instance = nil - once = sync.Once{} - - suite.store = GetSessionDataStore() + suite.store = newSessionDataStore() suite.store.ClearSessionStore() } @@ -59,59 +50,32 @@ func (suite *SessionDataStoreTestSuite) TearDownTest() { } func (suite *SessionDataStoreTestSuite) TestGetSessionDataStore() { - store := GetSessionDataStore() + store := newSessionDataStore() assert.NotNil(suite.T(), store) - assert.Implements(suite.T(), (*SessionDataStoreInterface)(nil), store) -} - -func (suite *SessionDataStoreTestSuite) TestGetSessionDataStoreSingleton() { - store1 := GetSessionDataStore() - store2 := GetSessionDataStore() - assert.Same(suite.T(), store1, store2) + assert.Implements(suite.T(), (*sessionDataStoreInterface)(nil), store) } func (suite *SessionDataStoreTestSuite) TestAddSession() { - sessionData := sessionmodel.SessionData{ + sessionData := SessionData{ OAuthParameters: model.OAuthParameters{ - SessionDataKey: testSessionKey, - ClientID: "test-client", - RedirectURI: "https://example.com/callback", - ResponseType: "code", - Scopes: "read write", - State: "test-state", + ClientID: "test-client", + RedirectURI: "https://example.com/callback", + ResponseType: "code", + Scopes: "read write", + State: "test-state", }, AuthTime: time.Now(), - AuthenticatedUser: authncm.AuthenticatedUser{ - IsAuthenticated: true, - UserID: "user123", - Attributes: map[string]interface{}{ - "username": "testuser", - "email": "test@example.com", - }, - }, } - suite.store.AddSession(testSessionKey, sessionData) - found, retrievedData := suite.store.GetSession(testSessionKey) + key := suite.store.AddSession(sessionData) + found, retrievedData := suite.store.GetSession(key) assert.True(suite.T(), found) assert.Equal(suite.T(), sessionData.OAuthParameters.ClientID, retrievedData.OAuthParameters.ClientID) - assert.Equal(suite.T(), sessionData.AuthenticatedUser.UserID, retrievedData.AuthenticatedUser.UserID) -} - -func (suite *SessionDataStoreTestSuite) TestAddSessionWithEmptyKey() { - sessionData := sessionmodel.SessionData{ - OAuthParameters: model.OAuthParameters{ - ClientID: "test-client", - }, - } - - suite.store.AddSession("", sessionData) - found, _ := suite.store.GetSession("") - assert.False(suite.T(), found) + assert.Equal(suite.T(), sessionData.AuthTime, retrievedData.AuthTime) } func (suite *SessionDataStoreTestSuite) TestGetSession() { - sessionData := sessionmodel.SessionData{ + sessionData := SessionData{ OAuthParameters: model.OAuthParameters{ ClientID: "test-client", State: "test-state", @@ -119,8 +83,8 @@ func (suite *SessionDataStoreTestSuite) TestGetSession() { AuthTime: time.Now(), } - suite.store.AddSession(testSessionKey, sessionData) - found, retrievedData := suite.store.GetSession(testSessionKey) + key := suite.store.AddSession(sessionData) + found, retrievedData := suite.store.GetSession(key) assert.True(suite.T(), found) assert.Equal(suite.T(), sessionData.OAuthParameters.ClientID, retrievedData.OAuthParameters.ClientID) assert.Equal(suite.T(), sessionData.OAuthParameters.State, retrievedData.OAuthParameters.State) @@ -137,19 +101,19 @@ func (suite *SessionDataStoreTestSuite) TestGetSessionWithEmptyKey() { } func (suite *SessionDataStoreTestSuite) TestClearSession() { - sessionData := sessionmodel.SessionData{ + sessionData := SessionData{ OAuthParameters: model.OAuthParameters{ ClientID: "test-client", }, } - suite.store.AddSession(testSessionKey, sessionData) + key := suite.store.AddSession(sessionData) - found, _ := suite.store.GetSession(testSessionKey) + found, _ := suite.store.GetSession(key) assert.True(suite.T(), found) - suite.store.ClearSession(testSessionKey) - found, _ = suite.store.GetSession(testSessionKey) + suite.store.ClearSession(key) + found, _ = suite.store.GetSession(key) assert.False(suite.T(), found) } @@ -158,15 +122,15 @@ func (suite *SessionDataStoreTestSuite) TestClearSessionWithEmptyKey() { } func (suite *SessionDataStoreTestSuite) TestClearSessionStore() { - keys := []string{"key1", "key2", "key3"} - sessionData := sessionmodel.SessionData{ - OAuthParameters: model.OAuthParameters{ - ClientID: "test-client", - }, - } + clientIDs := []string{"client1", "client2", "client3"} + keys := make([]string, 0, len(clientIDs)) - for _, key := range keys { - suite.store.AddSession(key, sessionData) + for _, clientID := range clientIDs { + keys = append(keys, suite.store.AddSession(SessionData{ + OAuthParameters: model.OAuthParameters{ + ClientID: clientID, + }, + })) } for _, key := range keys { @@ -182,67 +146,82 @@ func (suite *SessionDataStoreTestSuite) TestClearSessionStore() { } func (suite *SessionDataStoreTestSuite) TestSessionExpiry() { - key := "test-expiry-key" - sessionData := sessionmodel.SessionData{ + sessionData := SessionData{ OAuthParameters: model.OAuthParameters{ ClientID: "test-client", }, AuthTime: time.Now(), } - suite.store.AddSession(key, sessionData) + key := suite.store.AddSession(sessionData) found, _ := suite.store.GetSession(key) assert.True(suite.T(), found) } func (suite *SessionDataStoreTestSuite) TestConcurrentAccess() { - key := "concurrent-test-key" - sessionData := sessionmodel.SessionData{ - OAuthParameters: model.OAuthParameters{ - ClientID: "test-client", - }, - } - + numGoroutines := 100 var wg sync.WaitGroup - numGoroutines := 10 + keys := make([]string, numGoroutines) + var keysMutex sync.Mutex + // Test concurrent AddSession operations wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func(index int) { defer wg.Done() - keyWithIndex := key + string(rune('0'+index)) - suite.store.AddSession(keyWithIndex, sessionData) + sessionData := SessionData{ + OAuthParameters: model.OAuthParameters{ + ClientID: "test-client-" + string(rune('0'+index%10)), + State: "state-" + string(rune('0'+index%10)), + }, + AuthTime: time.Now(), + } + key := suite.store.AddSession(sessionData) + + keysMutex.Lock() + keys[index] = key + keysMutex.Unlock() }(i) } wg.Wait() + // Verify all keys are unique + keyMap := make(map[string]bool) + for _, key := range keys { + assert.NotEmpty(suite.T(), key, "Generated key should not be empty") + assert.False(suite.T(), keyMap[key], "Keys should be unique, found duplicate: "+key) + keyMap[key] = true + } + assert.Equal(suite.T(), numGoroutines, len(keyMap), "All keys should be unique") + + // Test concurrent GetSession operations wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func(index int) { defer wg.Done() - keyWithIndex := key + string(rune('0'+index)) - found, _ := suite.store.GetSession(keyWithIndex) - assert.True(suite.T(), found) + key := keys[index] + found, retrievedData := suite.store.GetSession(key) + assert.True(suite.T(), found, "Session should be found for key: "+key) + assert.NotEmpty(suite.T(), retrievedData.OAuthParameters.ClientID) }(i) } wg.Wait() - // Clear sessions concurrently + // Test concurrent ClearSession operations wg.Add(numGoroutines) for i := 0; i < numGoroutines; i++ { go func(index int) { defer wg.Done() - keyWithIndex := key + string(rune('0'+index)) - suite.store.ClearSession(keyWithIndex) + key := keys[index] + suite.store.ClearSession(key) }(i) } wg.Wait() // Verify all sessions are cleared - for i := 0; i < numGoroutines; i++ { - keyWithIndex := key + string(rune('0'+i)) - found, _ := suite.store.GetSession(keyWithIndex) - assert.False(suite.T(), found) + for i, key := range keys { + found, _ := suite.store.GetSession(key) + assert.False(suite.T(), found, "Session should be cleared for key at index %d: %s", i, key) } } diff --git a/backend/internal/oauth/oauth2/authz/store/store.go b/backend/internal/oauth/oauth2/authz/store.go similarity index 65% rename from backend/internal/oauth/oauth2/authz/store/store.go rename to backend/internal/oauth/oauth2/authz/store.go index 2e4fd901..4a3c0ed0 100644 --- a/backend/internal/oauth/oauth2/authz/store/store.go +++ b/backend/internal/oauth/oauth2/authz/store.go @@ -16,8 +16,7 @@ * under the License. */ -// Package store provides functionality for handling authorization code persistence and retrieval. -package store +package authz import ( "errors" @@ -25,40 +24,38 @@ import ( "strings" "time" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" "github.com/asgardeo/thunder/internal/system/database/provider" "github.com/asgardeo/thunder/internal/system/log" ) -const loggerComponentName = "AuthorizationCodeStore" +const storeLoggerComponentName = "AuthorizationCodeStore" // AuthorizationCodeStoreInterface defines the interface for managing authorization codes. type AuthorizationCodeStoreInterface interface { - InsertAuthorizationCode(authzCode model.AuthorizationCode) error - GetAuthorizationCode(clientID, authCode string) (model.AuthorizationCode, error) - DeactivateAuthorizationCode(authzCode model.AuthorizationCode) error - RevokeAuthorizationCode(authzCode model.AuthorizationCode) error - ExpireAuthorizationCode(authzCode model.AuthorizationCode) error + InsertAuthorizationCode(authzCode AuthorizationCode) error + GetAuthorizationCode(clientID, authCode string) (AuthorizationCode, error) + DeactivateAuthorizationCode(authzCode AuthorizationCode) error + RevokeAuthorizationCode(authzCode AuthorizationCode) error + ExpireAuthorizationCode(authzCode AuthorizationCode) error } -// AuthorizationCodeStore implements the AuthorizationCodeStoreInterface for managing authorization codes. -type AuthorizationCodeStore struct { - DBProvider provider.DBProviderInterface +// authorizationCodeStore implements the AuthorizationCodeStoreInterface for managing authorization codes. +type authorizationCodeStore struct { + dbProvider provider.DBProviderInterface } -// NewAuthorizationCodeStore creates a new instance of AuthorizationCodeStore. -func NewAuthorizationCodeStore() AuthorizationCodeStoreInterface { - return &AuthorizationCodeStore{ - DBProvider: provider.GetDBProvider(), +// newAuthorizationCodeStore creates a new instance of authorizationCodeStore with injected dependencies. +func newAuthorizationCodeStore() AuthorizationCodeStoreInterface { + return &authorizationCodeStore{ + dbProvider: provider.GetDBProvider(), } } // InsertAuthorizationCode inserts a new authorization code into the database. -func (acs *AuthorizationCodeStore) InsertAuthorizationCode(authzCode model.AuthorizationCode) error { - logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) +func (acs *authorizationCodeStore) InsertAuthorizationCode(authzCode AuthorizationCode) error { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, storeLoggerComponentName)) - dbClient, err := acs.DBProvider.GetDBClient("runtime") + dbClient, err := acs.dbProvider.GetDBClient("runtime") if err != nil { logger.Error("Failed to get database client", log.Error(err)) return err @@ -71,7 +68,7 @@ func (acs *AuthorizationCodeStore) InsertAuthorizationCode(authzCode model.Autho } // Insert authorization code. - _, err = tx.Exec(constants.QueryInsertAuthorizationCode.Query, authzCode.CodeID, authzCode.Code, + _, err = tx.Exec(queryInsertAuthorizationCode.Query, authzCode.CodeID, authzCode.Code, authzCode.ClientID, authzCode.RedirectURI, authzCode.AuthorizedUserID, authzCode.TimeCreated, authzCode.ExpiryTime, authzCode.State, authzCode.CodeChallenge, authzCode.CodeChallengeMethod) if err != nil { @@ -84,7 +81,7 @@ func (acs *AuthorizationCodeStore) InsertAuthorizationCode(authzCode model.Autho } // Insert auth code scopes. - _, err = tx.Exec(constants.QueryInsertAuthorizationCodeScopes.Query, authzCode.CodeID, + _, err = tx.Exec(queryInsertAuthorizationCodeScopes.Query, authzCode.CodeID, authzCode.Scopes) if err != nil { logger.Error("Failed to insert authorization code scopes", log.Error(err)) @@ -105,39 +102,39 @@ func (acs *AuthorizationCodeStore) InsertAuthorizationCode(authzCode model.Autho } // GetAuthorizationCode retrieves an authorization code by client Id and authorization code. -func (acs *AuthorizationCodeStore) GetAuthorizationCode(clientID, authCode string) (model.AuthorizationCode, error) { - logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) +func (acs *authorizationCodeStore) GetAuthorizationCode(clientID, authCode string) (AuthorizationCode, error) { + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, storeLoggerComponentName)) - dbClient, err := acs.DBProvider.GetDBClient("runtime") + dbClient, err := acs.dbProvider.GetDBClient("runtime") if err != nil { logger.Error("Failed to get database client", log.Error(err)) - return model.AuthorizationCode{}, err + return AuthorizationCode{}, err } - results, err := dbClient.Query(constants.QueryGetAuthorizationCode, clientID, authCode) + results, err := dbClient.Query(queryGetAuthorizationCode, clientID, authCode) if err != nil { - return model.AuthorizationCode{}, fmt.Errorf("error while retrieving authorization code: %w", err) + return AuthorizationCode{}, fmt.Errorf("error while retrieving authorization code: %w", err) } if len(results) == 0 { - return model.AuthorizationCode{}, constants.ErrAuthorizationCodeNotFound + return AuthorizationCode{}, ErrAuthorizationCodeNotFound } row := results[0] codeID := row["code_id"].(string) if codeID == "" { - return model.AuthorizationCode{}, constants.ErrAuthorizationCodeNotFound + return AuthorizationCode{}, ErrAuthorizationCodeNotFound } // Handle time_created field. timeCreated, err := parseTimeField(row["time_created"], "time_created", logger) if err != nil { - return model.AuthorizationCode{}, err + return AuthorizationCode{}, err } // Handle expiry_time field. expiryTime, err := parseTimeField(row["expiry_time"], "expiry_time", logger) if err != nil { - return model.AuthorizationCode{}, err + return AuthorizationCode{}, err } // Extract PKCE fields @@ -151,16 +148,16 @@ func (acs *AuthorizationCodeStore) GetAuthorizationCode(clientID, authCode strin } // Retrieve authorized scopes for the authorization code. - scopeResults, err := dbClient.Query(constants.QueryGetAuthorizationCodeScopes, codeID) + scopeResults, err := dbClient.Query(queryGetAuthorizationCodeScopes, codeID) if err != nil { - return model.AuthorizationCode{}, fmt.Errorf("error while retrieving authorized scopes: %w", err) + return AuthorizationCode{}, fmt.Errorf("error while retrieving authorized scopes: %w", err) } scopes := "" if len(scopeResults) > 0 { scopes = scopeResults[0]["scope"].(string) } - return model.AuthorizationCode{ + return AuthorizationCode{ CodeID: codeID, Code: row["authorization_code"].(string), ClientID: clientID, @@ -176,32 +173,32 @@ func (acs *AuthorizationCodeStore) GetAuthorizationCode(clientID, authCode strin } // DeactivateAuthorizationCode deactivates an authorization code. -func (acs *AuthorizationCodeStore) DeactivateAuthorizationCode(authzCode model.AuthorizationCode) error { - return acs.updateAuthorizationCodeState(authzCode, constants.AuthCodeStateInactive) +func (acs *authorizationCodeStore) DeactivateAuthorizationCode(authzCode AuthorizationCode) error { + return acs.updateAuthorizationCodeState(authzCode, AuthCodeStateInactive) } // RevokeAuthorizationCode revokes an authorization code. -func (acs *AuthorizationCodeStore) RevokeAuthorizationCode(authzCode model.AuthorizationCode) error { - return acs.updateAuthorizationCodeState(authzCode, constants.AuthCodeStateRevoked) +func (acs *authorizationCodeStore) RevokeAuthorizationCode(authzCode AuthorizationCode) error { + return acs.updateAuthorizationCodeState(authzCode, AuthCodeStateRevoked) } // ExpireAuthorizationCode expires an authorization code. -func (acs *AuthorizationCodeStore) ExpireAuthorizationCode(authzCode model.AuthorizationCode) error { - return acs.updateAuthorizationCodeState(authzCode, constants.AuthCodeStateExpired) +func (acs *authorizationCodeStore) ExpireAuthorizationCode(authzCode AuthorizationCode) error { + return acs.updateAuthorizationCodeState(authzCode, AuthCodeStateExpired) } // updateAuthorizationCodeState updates the state of an authorization code. -func (acs *AuthorizationCodeStore) updateAuthorizationCodeState(authzCode model.AuthorizationCode, +func (acs *authorizationCodeStore) updateAuthorizationCodeState(authzCode AuthorizationCode, newState string) error { - logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, loggerComponentName)) + logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, storeLoggerComponentName)) - dbClient, err := acs.DBProvider.GetDBClient("runtime") + dbClient, err := acs.dbProvider.GetDBClient("runtime") if err != nil { logger.Error("Failed to get database client", log.Error(err)) return err } - _, err = dbClient.Execute(constants.QueryUpdateAuthorizationCodeState, newState, authzCode.CodeID) + _, err = dbClient.Execute(queryUpdateAuthorizationCodeState, newState, authzCode.CodeID) return err } diff --git a/backend/internal/oauth/oauth2/authz/store/store_test.go b/backend/internal/oauth/oauth2/authz/store_test.go similarity index 72% rename from backend/internal/oauth/oauth2/authz/store/store_test.go rename to backend/internal/oauth/oauth2/authz/store_test.go index 7b6dcfc1..1718934f 100644 --- a/backend/internal/oauth/oauth2/authz/store/store_test.go +++ b/backend/internal/oauth/oauth2/authz/store_test.go @@ -16,7 +16,7 @@ * under the License. */ -package store +package authz import ( "errors" @@ -26,8 +26,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" "github.com/asgardeo/thunder/internal/system/config" "github.com/asgardeo/thunder/tests/mocks/database/clientmock" "github.com/asgardeo/thunder/tests/mocks/database/modelmock" @@ -36,10 +34,10 @@ import ( type AuthorizationCodeStoreTestSuite struct { suite.Suite - mockDBProvider *providermock.DBProviderInterfaceMock + mockdbProvider *providermock.DBProviderInterfaceMock mockDBClient *clientmock.DBClientInterfaceMock - store *AuthorizationCodeStore - testAuthzCode model.AuthorizationCode + store *authorizationCodeStore + testAuthzCode AuthorizationCode } func TestAuthorizationCodeStoreTestSuite(t *testing.T) { @@ -61,14 +59,14 @@ func (suite *AuthorizationCodeStoreTestSuite) SetupTest() { } _ = config.InitializeThunderRuntime("test", testConfig) - suite.mockDBProvider = &providermock.DBProviderInterfaceMock{} + suite.mockdbProvider = &providermock.DBProviderInterfaceMock{} suite.mockDBClient = &clientmock.DBClientInterfaceMock{} - suite.store = &AuthorizationCodeStore{ - DBProvider: suite.mockDBProvider, + suite.store = &authorizationCodeStore{ + dbProvider: suite.mockdbProvider, } - suite.testAuthzCode = model.AuthorizationCode{ + suite.testAuthzCode = AuthorizationCode{ CodeID: "test-code-id", Code: "test-code", ClientID: "test-client-id", @@ -77,14 +75,14 @@ func (suite *AuthorizationCodeStoreTestSuite) SetupTest() { TimeCreated: time.Now(), ExpiryTime: time.Now().Add(10 * time.Minute), Scopes: "read write", - State: constants.AuthCodeStateActive, + State: AuthCodeStateActive, CodeChallenge: "", CodeChallengeMethod: "", } } -func (suite *AuthorizationCodeStoreTestSuite) TestNewAuthorizationCodeStore() { - store := NewAuthorizationCodeStore() +func (suite *AuthorizationCodeStoreTestSuite) TestnewAuthorizationCodeStore() { + store := newAuthorizationCodeStore() assert.NotNil(suite.T(), store) assert.Implements(suite.T(), (*AuthorizationCodeStoreInterface)(nil), store) } @@ -92,17 +90,17 @@ func (suite *AuthorizationCodeStoreTestSuite) TestNewAuthorizationCodeStore() { func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_Success() { mockTx := &modelmock.TxInterfaceMock{} - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) suite.mockDBClient.On("BeginTx").Return(mockTx, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCode.Query, + mockTx.On("Exec", queryInsertAuthorizationCode.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Code, suite.testAuthzCode.ClientID, suite.testAuthzCode.RedirectURI, suite.testAuthzCode.AuthorizedUserID, suite.testAuthzCode.TimeCreated, suite.testAuthzCode.ExpiryTime, suite.testAuthzCode.State, suite.testAuthzCode.CodeChallenge, suite.testAuthzCode.CodeChallengeMethod). Return(nil, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCodeScopes.Query, + mockTx.On("Exec", queryInsertAuthorizationCodeScopes.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Scopes). Return(nil, nil) @@ -111,40 +109,40 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_Succes err := suite.store.InsertAuthorizationCode(suite.testAuthzCode) assert.NoError(suite.T(), err) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) mockTx.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_DBClientError() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) err := suite.store.InsertAuthorizationCode(suite.testAuthzCode) assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "db client error") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_BeginTxError() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) suite.mockDBClient.On("BeginTx").Return(nil, errors.New("tx error")) err := suite.store.InsertAuthorizationCode(suite.testAuthzCode) assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "failed to begin transaction") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ExecError() { mockTx := &modelmock.TxInterfaceMock{} - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) suite.mockDBClient.On("BeginTx").Return(mockTx, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCode.Query, + mockTx.On("Exec", queryInsertAuthorizationCode.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Code, suite.testAuthzCode.ClientID, suite.testAuthzCode.RedirectURI, suite.testAuthzCode.AuthorizedUserID, suite.testAuthzCode.TimeCreated, suite.testAuthzCode.ExpiryTime, suite.testAuthzCode.State, @@ -157,7 +155,7 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ExecEr assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "failed to insert authorization code") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) mockTx.AssertExpectations(suite.T()) } @@ -165,17 +163,17 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ExecEr func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ScopeExecError() { mockTx := &modelmock.TxInterfaceMock{} - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) suite.mockDBClient.On("BeginTx").Return(mockTx, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCode.Query, + mockTx.On("Exec", queryInsertAuthorizationCode.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Code, suite.testAuthzCode.ClientID, suite.testAuthzCode.RedirectURI, suite.testAuthzCode.AuthorizedUserID, suite.testAuthzCode.TimeCreated, suite.testAuthzCode.ExpiryTime, suite.testAuthzCode.State, suite.testAuthzCode.CodeChallenge, suite.testAuthzCode.CodeChallengeMethod). Return(nil, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCodeScopes.Query, + mockTx.On("Exec", queryInsertAuthorizationCodeScopes.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Scopes). Return(nil, errors.New("scope exec error")) @@ -185,7 +183,7 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ScopeE assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "failed to insert authorization code scopes") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) mockTx.AssertExpectations(suite.T()) } @@ -193,17 +191,17 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_ScopeE func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_CommitError() { mockTx := &modelmock.TxInterfaceMock{} - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) suite.mockDBClient.On("BeginTx").Return(mockTx, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCode.Query, + mockTx.On("Exec", queryInsertAuthorizationCode.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Code, suite.testAuthzCode.ClientID, suite.testAuthzCode.RedirectURI, suite.testAuthzCode.AuthorizedUserID, suite.testAuthzCode.TimeCreated, suite.testAuthzCode.ExpiryTime, suite.testAuthzCode.State, suite.testAuthzCode.CodeChallenge, suite.testAuthzCode.CodeChallengeMethod). Return(nil, nil) - mockTx.On("Exec", constants.QueryInsertAuthorizationCodeScopes.Query, + mockTx.On("Exec", queryInsertAuthorizationCodeScopes.Query, suite.testAuthzCode.CodeID, suite.testAuthzCode.Scopes). Return(nil, nil) @@ -213,7 +211,7 @@ func (suite *AuthorizationCodeStoreTestSuite) TestInsertAuthorizationCode_Commit assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "failed to commit transaction") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) mockTx.AssertExpectations(suite.T()) } @@ -230,7 +228,7 @@ func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_Success() "authz_user": "test-user-id", "time_created": testTimeStr, "expiry_time": testTimeStr, - "state": constants.AuthCodeStateActive, + "state": AuthCodeStateActive, }, } @@ -238,10 +236,10 @@ func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_Success() {"scope": "read write"}, } - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Query", constants.QueryGetAuthorizationCode, "test-client-id", "test-code"). + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetAuthorizationCode, "test-client-id", "test-code"). Return(queryResults, nil) - suite.mockDBClient.On("Query", constants.QueryGetAuthorizationCodeScopes, "test-code-id"). + suite.mockDBClient.On("Query", queryGetAuthorizationCodeScopes, "test-code-id"). Return(scopeResults, nil) result, err := suite.store.GetAuthorizationCode("test-client-id", "test-code") @@ -254,49 +252,49 @@ func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_Success() assert.Equal(suite.T(), testTime, result.TimeCreated) assert.Equal(suite.T(), testTime, result.ExpiryTime) assert.Equal(suite.T(), "read write", result.Scopes) - assert.Equal(suite.T(), constants.AuthCodeStateActive, result.State) + assert.Equal(suite.T(), AuthCodeStateActive, result.State) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_DBClientError() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) result, err := suite.store.GetAuthorizationCode("test-client-id", "test-code") assert.Error(suite.T(), err) - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), AuthorizationCode{}, result) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_QueryError() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Query", constants.QueryGetAuthorizationCode, "test-client-id", "test-code"). + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetAuthorizationCode, "test-client-id", "test-code"). Return(nil, errors.New("query error")) result, err := suite.store.GetAuthorizationCode("test-client-id", "test-code") assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "error while retrieving authorization code") - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), AuthorizationCode{}, result) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_NoResults() { queryResults := []map[string]interface{}{} - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Query", constants.QueryGetAuthorizationCode, "test-client-id", "test-code"). + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetAuthorizationCode, "test-client-id", "test-code"). Return(queryResults, nil) result, err := suite.store.GetAuthorizationCode("test-client-id", "test-code") assert.Error(suite.T(), err) - assert.Equal(suite.T(), constants.ErrAuthorizationCodeNotFound, err) - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), ErrAuthorizationCodeNotFound, err) + assert.Equal(suite.T(), AuthorizationCode{}, result) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } @@ -307,63 +305,63 @@ func (suite *AuthorizationCodeStoreTestSuite) TestGetAuthorizationCode_EmptyCode }, } - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Query", constants.QueryGetAuthorizationCode, "test-client-id", "test-code"). + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Query", queryGetAuthorizationCode, "test-client-id", "test-code"). Return(queryResults, nil) result, err := suite.store.GetAuthorizationCode("test-client-id", "test-code") assert.Error(suite.T(), err) - assert.Equal(suite.T(), constants.ErrAuthorizationCodeNotFound, err) - assert.Equal(suite.T(), model.AuthorizationCode{}, result) + assert.Equal(suite.T(), ErrAuthorizationCodeNotFound, err) + assert.Equal(suite.T(), AuthorizationCode{}, result) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestDeactivateAuthorizationCode_Success() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Execute", constants.QueryUpdateAuthorizationCodeState, - constants.AuthCodeStateInactive, suite.testAuthzCode.CodeID).Return(int64(1), nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Execute", queryUpdateAuthorizationCodeState, + AuthCodeStateInactive, suite.testAuthzCode.CodeID).Return(int64(1), nil) err := suite.store.DeactivateAuthorizationCode(suite.testAuthzCode) assert.NoError(suite.T(), err) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestRevokeAuthorizationCode_Success() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Execute", constants.QueryUpdateAuthorizationCodeState, - constants.AuthCodeStateRevoked, suite.testAuthzCode.CodeID).Return(int64(1), nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Execute", queryUpdateAuthorizationCodeState, + AuthCodeStateRevoked, suite.testAuthzCode.CodeID).Return(int64(1), nil) err := suite.store.RevokeAuthorizationCode(suite.testAuthzCode) assert.NoError(suite.T(), err) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestExpireAuthorizationCode_Success() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) - suite.mockDBClient.On("Execute", constants.QueryUpdateAuthorizationCodeState, - constants.AuthCodeStateExpired, suite.testAuthzCode.CodeID).Return(int64(1), nil) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(suite.mockDBClient, nil) + suite.mockDBClient.On("Execute", queryUpdateAuthorizationCodeState, + AuthCodeStateExpired, suite.testAuthzCode.CodeID).Return(int64(1), nil) err := suite.store.ExpireAuthorizationCode(suite.testAuthzCode) assert.NoError(suite.T(), err) - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) suite.mockDBClient.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestUpdateAuthorizationCodeState_Error() { - suite.mockDBProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) + suite.mockdbProvider.On("GetDBClient", "runtime").Return(nil, errors.New("db client error")) err := suite.store.DeactivateAuthorizationCode(suite.testAuthzCode) assert.Error(suite.T(), err) assert.Contains(suite.T(), err.Error(), "db client error") - suite.mockDBProvider.AssertExpectations(suite.T()) + suite.mockdbProvider.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeStoreTestSuite) TestParseTimeField_StringInput() { diff --git a/backend/internal/oauth/oauth2/authz/constants/dbconstants.go b/backend/internal/oauth/oauth2/authz/storeconstants.go similarity index 73% rename from backend/internal/oauth/oauth2/authz/constants/dbconstants.go rename to backend/internal/oauth/oauth2/authz/storeconstants.go index 5d5dabc6..6151c0bb 100644 --- a/backend/internal/oauth/oauth2/authz/constants/dbconstants.go +++ b/backend/internal/oauth/oauth2/authz/storeconstants.go @@ -16,40 +16,40 @@ * under the License. */ -package constants +package authz import dbmodel "github.com/asgardeo/thunder/internal/system/database/model" -// QueryInsertAuthorizationCode is the query to insert a new authorization code into the database. -var QueryInsertAuthorizationCode = dbmodel.DBQuery{ +// queryInsertAuthorizationCode is the query to insert a new authorization code into the database. +var queryInsertAuthorizationCode = dbmodel.DBQuery{ ID: "AZQ-00001", Query: "INSERT INTO IDN_OAUTH2_AUTHZ_CODE (CODE_ID, AUTHORIZATION_CODE, CONSUMER_KEY, " + "CALLBACK_URL, AUTHZ_USER, TIME_CREATED, EXPIRY_TIME, STATE, CODE_CHALLENGE, CODE_CHALLENGE_METHOD)" + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)", } -// QueryInsertAuthorizationCodeScopes is the query to insert scopes for an authorization code. -var QueryInsertAuthorizationCodeScopes = dbmodel.DBQuery{ +// queryInsertAuthorizationCodeScopes is the query to insert scopes for an authorization code. +var queryInsertAuthorizationCodeScopes = dbmodel.DBQuery{ ID: "AZQ-00002", Query: "INSERT INTO IDN_OAUTH2_AUTHZ_CODE_SCOPE (CODE_ID, SCOPE) VALUES ($1, $2)", } -// QueryGetAuthorizationCode is the query to retrieve an authorization code by client ID and code. -var QueryGetAuthorizationCode = dbmodel.DBQuery{ +// queryGetAuthorizationCode is the query to retrieve an authorization code by client ID and code. +var queryGetAuthorizationCode = dbmodel.DBQuery{ ID: "AZQ-00003", Query: "SELECT CODE_ID, AUTHORIZATION_CODE, CALLBACK_URL, AUTHZ_USER, TIME_CREATED, " + "EXPIRY_TIME, STATE, CODE_CHALLENGE, CODE_CHALLENGE_METHOD FROM IDN_OAUTH2_AUTHZ_CODE WHERE " + "CONSUMER_KEY = $1 AND AUTHORIZATION_CODE = $2", } -// QueryUpdateAuthorizationCodeState is the query to update the state of an authorization code. -var QueryUpdateAuthorizationCodeState = dbmodel.DBQuery{ +// queryUpdateAuthorizationCodeState is the query to update the state of an authorization code. +var queryUpdateAuthorizationCodeState = dbmodel.DBQuery{ ID: "AZQ-00004", Query: "UPDATE IDN_OAUTH2_AUTHZ_CODE SET STATE = $1 WHERE CODE_ID = $2", } -// QueryGetAuthorizationCodeScopes is the query to retrieve scopes for an authorization code. -var QueryGetAuthorizationCodeScopes = dbmodel.DBQuery{ +// queryGetAuthorizationCodeScopes is the query to retrieve scopes for an authorization code. +var queryGetAuthorizationCodeScopes = dbmodel.DBQuery{ ID: "AZQ-00005", Query: "SELECT SCOPE FROM IDN_OAUTH2_AUTHZ_CODE_SCOPE WHERE CODE_ID = $1", } diff --git a/backend/internal/oauth/oauth2/authz/authzvalidator.go b/backend/internal/oauth/oauth2/authz/validator.go similarity index 86% rename from backend/internal/oauth/oauth2/authz/authzvalidator.go rename to backend/internal/oauth/oauth2/authz/validator.go index d7a87c43..0e64f056 100644 --- a/backend/internal/oauth/oauth2/authz/authzvalidator.go +++ b/backend/internal/oauth/oauth2/authz/validator.go @@ -20,7 +20,6 @@ package authz import ( appmodel "github.com/asgardeo/thunder/internal/application/model" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" "github.com/asgardeo/thunder/internal/oauth/oauth2/pkce" "github.com/asgardeo/thunder/internal/system/log" @@ -28,20 +27,20 @@ import ( // AuthorizationValidatorInterface defines the interface for validating OAuth2 authorization requests. type AuthorizationValidatorInterface interface { - validateInitialAuthorizationRequest(msg *model.OAuthMessage, oauthApp *appmodel.OAuthAppConfigProcessedDTO) ( + validateInitialAuthorizationRequest(msg *OAuthMessage, oauthApp *appmodel.OAuthAppConfigProcessedDTO) ( bool, string, string) } -// AuthorizationValidator implements the AuthorizationValidatorInterface for validating OAuth2 authorization requests. -type AuthorizationValidator struct{} +// authorizationValidator implements the AuthorizationValidatorInterface for validating OAuth2 authorization requests. +type authorizationValidator struct{} -// NewAuthorizationValidator creates a new instance of AuthorizationValidator. -func NewAuthorizationValidator() AuthorizationValidatorInterface { - return &AuthorizationValidator{} +// newAuthorizationValidator creates a new instance of authorizationValidator. +func newAuthorizationValidator() AuthorizationValidatorInterface { + return &authorizationValidator{} } // validateInitialAuthorizationRequest validates the initial authorization request parameters. -func (av *AuthorizationValidator) validateInitialAuthorizationRequest(msg *model.OAuthMessage, +func (av *authorizationValidator) validateInitialAuthorizationRequest(msg *OAuthMessage, oauthApp *appmodel.OAuthAppConfigProcessedDTO) (bool, string, string) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "AuthorizationValidator")) diff --git a/backend/internal/oauth/oauth2/authz/authzvalidator_test.go b/backend/internal/oauth/oauth2/authz/validator_test.go similarity index 94% rename from backend/internal/oauth/oauth2/authz/authzvalidator_test.go rename to backend/internal/oauth/oauth2/authz/validator_test.go index d49ab674..3bdf331b 100644 --- a/backend/internal/oauth/oauth2/authz/authzvalidator_test.go +++ b/backend/internal/oauth/oauth2/authz/validator_test.go @@ -25,7 +25,7 @@ import ( "github.com/stretchr/testify/suite" appmodel "github.com/asgardeo/thunder/internal/application/model" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" ) @@ -40,7 +40,7 @@ func TestAuthorizationValidatorTestSuite(t *testing.T) { } func (suite *AuthorizationValidatorTestSuite) SetupTest() { - suite.validator = NewAuthorizationValidator() + suite.validator = newAuthorizationValidator() suite.oauthApp = &appmodel.OAuthAppConfigProcessedDTO{ ClientID: "test-client-id", @@ -53,14 +53,14 @@ func (suite *AuthorizationValidatorTestSuite) SetupTest() { } } -func (suite *AuthorizationValidatorTestSuite) TestNewAuthorizationValidator() { - validator := NewAuthorizationValidator() +func (suite *AuthorizationValidatorTestSuite) TestnewAuthorizationValidator() { + validator := newAuthorizationValidator() assert.NotNil(suite.T(), validator) assert.Implements(suite.T(), (*AuthorizationValidatorInterface)(nil), validator) } func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRequest_Success() { - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "https://client.example.com/callback", @@ -77,7 +77,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRe } func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRequest_MissingClientID() { - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamRedirectURI: "https://client.example.com/callback", constants.RequestParamResponseType: string(constants.ResponseTypeCode), @@ -93,7 +93,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRe } func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRequest_InvalidRedirectURI() { - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "https://malicious.example.com/callback", // not in allowed list @@ -121,7 +121,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateAuthzRequest_CodeGrant constants.TokenEndpointAuthMethodClientSecretPost}, } - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "https://client.example.com/callback", @@ -138,7 +138,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateAuthzRequest_CodeGrant } func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRequest_MissingResponseType() { - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "https://client.example.com/callback", @@ -165,7 +165,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRe constants.TokenEndpointAuthMethodClientSecretPost}, } - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "https://client.example.com/callback", @@ -182,7 +182,7 @@ func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRe } func (suite *AuthorizationValidatorTestSuite) TestValidateInitialAuthorizationRequest_EmptyRedirectURI() { - msg := &model.OAuthMessage{ + msg := &OAuthMessage{ RequestQueryParams: map[string]string{ constants.RequestParamClientID: "test-client-id", constants.RequestParamRedirectURI: "", // empty redirect URI should be OK if app has only one registered diff --git a/backend/internal/oauth/oauth2/granthandlers/authorizationcode.go b/backend/internal/oauth/oauth2/granthandlers/authorizationcode.go index 4bb7fc02..f4335fd2 100644 --- a/backend/internal/oauth/oauth2/granthandlers/authorizationcode.go +++ b/backend/internal/oauth/oauth2/granthandlers/authorizationcode.go @@ -25,9 +25,7 @@ import ( "time" appmodel "github.com/asgardeo/thunder/internal/application/model" - authzconstants "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - authzmodel "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/store" + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" "github.com/asgardeo/thunder/internal/oauth/oauth2/model" "github.com/asgardeo/thunder/internal/oauth/oauth2/pkce" @@ -47,17 +45,21 @@ const ( // authorizationCodeGrantHandler handles the authorization code grant type. type authorizationCodeGrantHandler struct { - JWTService jwt.JWTServiceInterface - AuthZStore store.AuthorizationCodeStoreInterface - UserService user.UserServiceInterface + jwtService jwt.JWTServiceInterface + authzService authz.AuthorizeServiceInterface + userService user.UserServiceInterface } // newAuthorizationCodeGrantHandler creates a new instance of AuthorizationCodeGrantHandler. -func newAuthorizationCodeGrantHandler() GrantHandlerInterface { +func newAuthorizationCodeGrantHandler( + jwtService jwt.JWTServiceInterface, + userService user.UserServiceInterface, + authzService authz.AuthorizeServiceInterface, +) GrantHandlerInterface { return &authorizationCodeGrantHandler{ - JWTService: jwt.GetJWTService(), - AuthZStore: store.NewAuthorizationCodeStore(), - UserService: user.GetUserService(), + jwtService: jwtService, + authzService: authzService, + userService: userService, } } @@ -136,7 +138,7 @@ func (h *authorizationCodeGrantHandler) HandleGrant(tokenRequest *model.TokenReq // Generate access token iss, validityPeriod := resolveTokenConfig(oauthApp) - token, _, err := h.JWTService.GenerateJWT(authCode.AuthorizedUserID, authCode.ClientID, + token, _, err := h.jwtService.GenerateJWT(authCode.AuthorizedUserID, authCode.ClientID, iss, validityPeriod, jwtClaims) if err != nil { return nil, &model.ErrorResponse{ @@ -168,25 +170,25 @@ func (h *authorizationCodeGrantHandler) retrieveAndValidateAuthCode( tokenRequest *model.TokenRequest, oauthApp *appmodel.OAuthAppConfigProcessedDTO, logger *log.Logger, -) (authzmodel.AuthorizationCode, *model.ErrorResponse) { - authCode, err := h.AuthZStore.GetAuthorizationCode(tokenRequest.ClientID, tokenRequest.Code) - if err != nil || authCode.Code == "" { - return authzmodel.AuthorizationCode{}, &model.ErrorResponse{ +) (*authz.AuthorizationCode, *model.ErrorResponse) { + authCode, codeErr := h.authzService.GetAuthorizationCodeDetails(tokenRequest.ClientID, tokenRequest.Code) + if codeErr != nil { + return nil, &model.ErrorResponse{ Error: constants.ErrorInvalidGrant, ErrorDescription: "Invalid authorization code", } } // Validate the retrieved authorization code - errResponse := validateAuthorizationCode(tokenRequest, authCode) + errResponse := validateAuthorizationCode(tokenRequest, *authCode) if errResponse != nil && errResponse.Error != "" { - return authzmodel.AuthorizationCode{}, errResponse + return nil, errResponse } // Validate PKCE if required or if code challenge was provided during authorization if oauthApp.RequiresPKCE() || authCode.CodeChallenge != "" { if tokenRequest.CodeVerifier == "" { - return authzmodel.AuthorizationCode{}, &model.ErrorResponse{ + return nil, &model.ErrorResponse{ Error: constants.ErrorInvalidRequest, ErrorDescription: "code_verifier is required", } @@ -196,22 +198,12 @@ func (h *authorizationCodeGrantHandler) retrieveAndValidateAuthCode( if err := pkce.ValidatePKCE(authCode.CodeChallenge, authCode.CodeChallengeMethod, tokenRequest.CodeVerifier); err != nil { logger.Debug("PKCE validation failed", log.Error(err)) - return authzmodel.AuthorizationCode{}, &model.ErrorResponse{ + return nil, &model.ErrorResponse{ Error: constants.ErrorInvalidGrant, ErrorDescription: "Invalid code verifier", } } } - - // Invalidate the authorization code after use - err = h.AuthZStore.DeactivateAuthorizationCode(authCode) - if err != nil { - return authzmodel.AuthorizationCode{}, &model.ErrorResponse{ - Error: constants.ErrorServerError, - ErrorDescription: "Failed to invalidate authorization code", - } - } - return authCode, nil } @@ -241,7 +233,7 @@ func (h *authorizationCodeGrantHandler) fetchUserAttributesAndGroups( } // Fetch user attributes - user, svcErr := h.UserService.GetUser(userID) + user, svcErr := h.userService.GetUser(userID) if svcErr != nil { logger.Error("Failed to fetch user attributes", log.String("userID", userID), log.Any("error", svcErr)) return nil, nil, &model.ErrorResponse{ @@ -263,7 +255,7 @@ func (h *authorizationCodeGrantHandler) fetchUserAttributesAndGroups( (oauthApp.Token.IDToken != nil && slices.Contains(oauthApp.Token.IDToken.UserAttributes, UserAttributeGroups)) if needsGroups { - groups, svcErr := h.UserService.GetUserGroups(userID, DefaultGroupListLimit, 0) + groups, svcErr := h.userService.GetUserGroups(userID, DefaultGroupListLimit, 0) if svcErr != nil { logger.Error("Failed to fetch user groups", log.String("userID", userID), log.Any("error", svcErr)) return nil, nil, &model.ErrorResponse{ @@ -343,7 +335,7 @@ func resolveTokenConfig(oauthApp *appmodel.OAuthAppConfigProcessedDTO) (string, } // updateContextAttributes updates the token context with subject and audience. -func updateContextAttributes(ctx *model.TokenContext, authCode authzmodel.AuthorizationCode) { +func updateContextAttributes(ctx *model.TokenContext, authCode *authz.AuthorizationCode) { if ctx.TokenAttributes == nil { ctx.TokenAttributes = make(map[string]interface{}) } @@ -375,7 +367,7 @@ func buildTokenResponse( // validateAuthorizationCode validates the authorization code against the token request. func validateAuthorizationCode(tokenRequest *model.TokenRequest, - code authzmodel.AuthorizationCode) *model.ErrorResponse { + code authz.AuthorizationCode) *model.ErrorResponse { if tokenRequest.ClientID != code.ClientID { return &model.ErrorResponse{ Error: constants.ErrorInvalidClient, @@ -391,14 +383,14 @@ func validateAuthorizationCode(tokenRequest *model.TokenRequest, } } - if code.State == authzconstants.AuthCodeStateInactive { + if code.State == authz.AuthCodeStateInactive { // TODO: Revoke all the tokens issued for this authorization code. return &model.ErrorResponse{ Error: constants.ErrorInvalidGrant, ErrorDescription: "Inactive authorization code", } - } else if code.State != authzconstants.AuthCodeStateActive { + } else if code.State != authz.AuthCodeStateActive { return &model.ErrorResponse{ Error: constants.ErrorInvalidGrant, ErrorDescription: "Inactive authorization code", @@ -458,7 +450,7 @@ func getIDTokenClaims(scopes []string, userAttributes map[string]interface{}, } // generateIDToken generates an ID token for the given authorization code and scopes -func (h *authorizationCodeGrantHandler) generateIDToken(authCode authzmodel.AuthorizationCode, +func (h *authorizationCodeGrantHandler) generateIDToken(authCode *authz.AuthorizationCode, tokenRequest *model.TokenRequest, authorizedScopes []string, attrs map[string]interface{}, oauthApp *appmodel.OAuthAppConfigProcessedDTO) (*model.TokenDTO, *model.ErrorResponse) { idTokenClaims := getIDTokenClaims(authorizedScopes, attrs, oauthApp) @@ -483,7 +475,7 @@ func (h *authorizationCodeGrantHandler) generateIDToken(authCode authzmodel.Auth } // Generate ID token JWT - idToken, _, err := h.JWTService.GenerateJWT(authCode.AuthorizedUserID, authCode.ClientID, + idToken, _, err := h.jwtService.GenerateJWT(authCode.AuthorizedUserID, authCode.ClientID, idTokenIss, idTokenValidityPeriod, idTokenClaims) if err != nil { return nil, &model.ErrorResponse{ diff --git a/backend/internal/oauth/oauth2/granthandlers/authorizationcode_test.go b/backend/internal/oauth/oauth2/granthandlers/authorizationcode_test.go index 8fa89ec3..094b8fee 100644 --- a/backend/internal/oauth/oauth2/granthandlers/authorizationcode_test.go +++ b/backend/internal/oauth/oauth2/granthandlers/authorizationcode_test.go @@ -29,26 +29,25 @@ import ( "github.com/stretchr/testify/suite" appmodel "github.com/asgardeo/thunder/internal/application/model" - authzconstants "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/constants" - authzmodel "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" "github.com/asgardeo/thunder/internal/oauth/oauth2/model" "github.com/asgardeo/thunder/internal/system/config" "github.com/asgardeo/thunder/internal/user" "github.com/asgardeo/thunder/tests/mocks/jwtmock" - "github.com/asgardeo/thunder/tests/mocks/oauth/oauth2/authz/storemock" + "github.com/asgardeo/thunder/tests/mocks/oauth/oauth2/authzmock" usersvcmock "github.com/asgardeo/thunder/tests/mocks/usermock" ) type AuthorizationCodeGrantHandlerTestSuite struct { suite.Suite - handler *authorizationCodeGrantHandler - mockJWTService *jwtmock.JWTServiceInterfaceMock - mockAuthZStore *storemock.AuthorizationCodeStoreInterfaceMock - mockUserService *usersvcmock.UserServiceInterfaceMock - oauthApp *appmodel.OAuthAppConfigProcessedDTO - testAuthzCode authzmodel.AuthorizationCode - testTokenReq *model.TokenRequest + handler *authorizationCodeGrantHandler + mockJWTService *jwtmock.JWTServiceInterfaceMock + mockAuthzService *authzmock.AuthorizeServiceInterfaceMock + mockUserService *usersvcmock.UserServiceInterfaceMock + oauthApp *appmodel.OAuthAppConfigProcessedDTO + testAuthzCode authz.AuthorizationCode + testTokenReq *model.TokenRequest } func TestAuthorizationCodeGrantHandlerSuite(t *testing.T) { @@ -65,13 +64,13 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) SetupTest() { _ = config.InitializeThunderRuntime("test", testConfig) suite.mockJWTService = &jwtmock.JWTServiceInterfaceMock{} - suite.mockAuthZStore = &storemock.AuthorizationCodeStoreInterfaceMock{} + suite.mockAuthzService = &authzmock.AuthorizeServiceInterfaceMock{} suite.mockUserService = usersvcmock.NewUserServiceInterfaceMock(suite.T()) suite.handler = &authorizationCodeGrantHandler{ - JWTService: suite.mockJWTService, - AuthZStore: suite.mockAuthZStore, - UserService: suite.mockUserService, + jwtService: suite.mockJWTService, + authzService: suite.mockAuthzService, + userService: suite.mockUserService, } suite.oauthApp = &appmodel.OAuthAppConfigProcessedDTO{ @@ -96,7 +95,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) SetupTest() { RedirectURI: "https://client.example.com/callback", } - suite.testAuthzCode = authzmodel.AuthorizationCode{ + suite.testAuthzCode = authz.AuthorizationCode{ CodeID: "test-code-id", Code: "test-auth-code", ClientID: "test-client-id", @@ -105,12 +104,12 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) SetupTest() { TimeCreated: time.Now().Add(-5 * time.Minute), ExpiryTime: time.Now().Add(5 * time.Minute), Scopes: "read write", - State: authzconstants.AuthCodeStateActive, + State: authz.AuthCodeStateActive, } } func (suite *AuthorizationCodeGrantHandlerTestSuite) TestNewAuthorizationCodeGrantHandler() { - handler := newAuthorizationCodeGrantHandler() + handler := newAuthorizationCodeGrantHandler(suite.mockJWTService, suite.mockUserService, suite.mockAuthzService) assert.NotNil(suite.T(), handler) assert.Implements(suite.T(), (*GrantHandlerInterface)(nil), handler) } @@ -188,9 +187,8 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestValidateGrant_MissingRe func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_Success() { // Mock authorization code store to return valid code - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(suite.testAuthzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", suite.testAuthzCode).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&suite.testAuthzCode, nil) // Mock user service to return user for attributes mockUser := &user.User{ @@ -222,14 +220,14 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_Success() { assert.Equal(suite.T(), "test-user-id", ctx.TokenAttributes["sub"]) assert.Equal(suite.T(), "test-client-id", ctx.TokenAttributes["aud"]) - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_InvalidAuthorizationCode() { // Mock authorization code store to return error - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(authzmodel.AuthorizationCode{}, errors.New("code not found")) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(nil, errors.New("invalid authorization code")) ctx := &model.TokenContext{ TokenAttributes: make(map[string]interface{}), @@ -242,55 +240,13 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_InvalidAuth assert.Equal(suite.T(), constants.ErrorInvalidGrant, err.Error) assert.Equal(suite.T(), "Invalid authorization code", err.ErrorDescription) - suite.mockAuthZStore.AssertExpectations(suite.T()) -} - -func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_EmptyAuthorizationCode() { - // Mock authorization code store to return empty code - emptyCode := authzmodel.AuthorizationCode{Code: ""} - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(emptyCode, nil) - - ctx := &model.TokenContext{ - TokenAttributes: make(map[string]interface{}), - } - - result, err := suite.handler.HandleGrant(suite.testTokenReq, suite.oauthApp, ctx) - - assert.Nil(suite.T(), result) - assert.NotNil(suite.T(), err) - assert.Equal(suite.T(), constants.ErrorInvalidGrant, err.Error) - assert.Equal(suite.T(), "Invalid authorization code", err.ErrorDescription) - - suite.mockAuthZStore.AssertExpectations(suite.T()) -} - -func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_DeactivateError() { - // Mock authorization code store to return valid code but fail deactivation - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(suite.testAuthzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", suite.testAuthzCode). - Return(errors.New("deactivate failed")) - - ctx := &model.TokenContext{ - TokenAttributes: make(map[string]interface{}), - } - - result, err := suite.handler.HandleGrant(suite.testTokenReq, suite.oauthApp, ctx) - - assert.Nil(suite.T(), result) - assert.NotNil(suite.T(), err) - assert.Equal(suite.T(), constants.ErrorServerError, err.Error) - assert.Equal(suite.T(), "Failed to invalidate authorization code", err.ErrorDescription) - - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_JWTGenerationError() { // Mock authorization code store to return valid code - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(suite.testAuthzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", suite.testAuthzCode).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&suite.testAuthzCode, nil) // Mock user service to return user for attributes mockUser := &user.User{ @@ -315,7 +271,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_JWTGenerati assert.Equal(suite.T(), constants.ErrorServerError, err.Error) assert.Equal(suite.T(), "Failed to generate token", err.ErrorDescription) - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) } @@ -324,9 +280,8 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_EmptyScopes authzCodeWithEmptyScopes := suite.testAuthzCode authzCodeWithEmptyScopes.Scopes = "" - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(authzCodeWithEmptyScopes, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", authzCodeWithEmptyScopes).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&authzCodeWithEmptyScopes, nil) // Mock user service to return user for attributes mockUser := &user.User{ @@ -349,15 +304,14 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_EmptyScopes assert.NotNil(suite.T(), result) assert.Empty(suite.T(), result.AccessToken.Scopes) - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) } func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_NilTokenAttributes() { // Test with nil token attributes - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(suite.testAuthzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", suite.testAuthzCode).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&suite.testAuthzCode, nil) // Mock user service to return user for attributes mockUser := &user.User{ @@ -384,7 +338,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_NilTokenAtt assert.Equal(suite.T(), "test-user-id", ctx.TokenAttributes["sub"]) assert.Equal(suite.T(), "test-client-id", ctx.TokenAttributes["aud"]) - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) } @@ -432,7 +386,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestValidateAuthorizationCo func (suite *AuthorizationCodeGrantHandlerTestSuite) TestValidateAuthorizationCode_InactiveCode() { inactiveCode := suite.testAuthzCode - inactiveCode.State = authzconstants.AuthCodeStateInactive + inactiveCode.State = authz.AuthCodeStateInactive err := validateAuthorizationCode(suite.testTokenReq, inactiveCode) assert.NotNil(suite.T(), err) @@ -503,12 +457,14 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithGroups( for _, tc := range testCases { suite.Run(tc.name, func() { // Reset mocks for each test case - suite.mockAuthZStore = &storemock.AuthorizationCodeStoreInterfaceMock{} + suite.mockAuthzService = &authzmock.AuthorizeServiceInterfaceMock{} suite.mockUserService = usersvcmock.NewUserServiceInterfaceMock(suite.T()) suite.mockJWTService = &jwtmock.JWTServiceInterfaceMock{} - suite.handler.AuthZStore = suite.mockAuthZStore - suite.handler.UserService = suite.mockUserService - suite.handler.JWTService = suite.mockJWTService + suite.handler = &authorizationCodeGrantHandler{ + jwtService: suite.mockJWTService, + authzService: suite.mockAuthzService, + userService: suite.mockUserService, + } accessTokenAttrs := []string{"email", "username"} if tc.includeInAccessToken { @@ -552,9 +508,8 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithGroups( authzCode.Scopes = "openid read write" } - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(authzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", authzCode).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&authzCode, nil) mockUser := &user.User{ ID: "test-user-id", @@ -632,7 +587,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithGroups( assert.Empty(suite.T(), result.IDToken.Token, tc.description) } - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockUserService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) }) @@ -668,12 +623,14 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithEmptyGr for _, tc := range testCases { suite.Run(tc.name, func() { - suite.mockAuthZStore = &storemock.AuthorizationCodeStoreInterfaceMock{} + suite.mockAuthzService = &authzmock.AuthorizeServiceInterfaceMock{} suite.mockUserService = usersvcmock.NewUserServiceInterfaceMock(suite.T()) suite.mockJWTService = &jwtmock.JWTServiceInterfaceMock{} - suite.handler.AuthZStore = suite.mockAuthZStore - suite.handler.UserService = suite.mockUserService - suite.handler.JWTService = suite.mockJWTService + suite.handler = &authorizationCodeGrantHandler{ + jwtService: suite.mockJWTService, + authzService: suite.mockAuthzService, + userService: suite.mockUserService, + } accessTokenAttrs := []string{"email", "username"} if tc.includeInAccessToken { @@ -716,9 +673,8 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithEmptyGr authzCode.Scopes = "openid read write" } - suite.mockAuthZStore.On("GetAuthorizationCode", "test-client-id", "test-auth-code"). - Return(authzCode, nil) - suite.mockAuthZStore.On("DeactivateAuthorizationCode", authzCode).Return(nil) + suite.mockAuthzService.On("GetAuthorizationCodeDetails", "test-client-id", "test-auth-code"). + Return(&authzCode, nil) mockUser := &user.User{ ID: "test-user-id", @@ -778,7 +734,7 @@ func (suite *AuthorizationCodeGrantHandlerTestSuite) TestHandleGrant_WithEmptyGr assert.Empty(suite.T(), result.IDToken.Token, tc.description) } - suite.mockAuthZStore.AssertExpectations(suite.T()) + suite.mockAuthzService.AssertExpectations(suite.T()) suite.mockUserService.AssertExpectations(suite.T()) suite.mockJWTService.AssertExpectations(suite.T()) }) diff --git a/backend/internal/oauth/oauth2/granthandlers/clientcredentials.go b/backend/internal/oauth/oauth2/granthandlers/clientcredentials.go index fe47c1c8..f27ca87a 100644 --- a/backend/internal/oauth/oauth2/granthandlers/clientcredentials.go +++ b/backend/internal/oauth/oauth2/granthandlers/clientcredentials.go @@ -31,13 +31,13 @@ import ( // clientCredentialsGrantHandler handles the client credentials grant type. type clientCredentialsGrantHandler struct { - JWTService jwt.JWTServiceInterface + jwtService jwt.JWTServiceInterface } // newClientCredentialsGrantHandler creates a new instance of ClientCredentialsGrantHandler. -func newClientCredentialsGrantHandler() GrantHandlerInterface { +func newClientCredentialsGrantHandler(jwtService jwt.JWTServiceInterface) GrantHandlerInterface { return &clientCredentialsGrantHandler{ - JWTService: jwt.GetJWTService(), + jwtService: jwtService, } } @@ -92,7 +92,7 @@ func (h *clientCredentialsGrantHandler) HandleGrant(tokenRequest *model.TokenReq validityPeriod = config.GetThunderRuntime().Config.JWT.ValidityPeriod } - token, _, err := h.JWTService.GenerateJWT(tokenRequest.ClientID, tokenRequest.ClientID, iss, + token, _, err := h.jwtService.GenerateJWT(tokenRequest.ClientID, tokenRequest.ClientID, iss, validityPeriod, jwtClaims) if err != nil { return nil, &model.ErrorResponse{ diff --git a/backend/internal/oauth/oauth2/granthandlers/clientcredentials_test.go b/backend/internal/oauth/oauth2/granthandlers/clientcredentials_test.go index d65de2df..4689a034 100644 --- a/backend/internal/oauth/oauth2/granthandlers/clientcredentials_test.go +++ b/backend/internal/oauth/oauth2/granthandlers/clientcredentials_test.go @@ -61,7 +61,7 @@ func (suite *ClientCredentialsGrantHandlerTestSuite) SetupTest() { suite.mockJWTService = jwtmock.NewJWTServiceInterfaceMock(suite.T()) suite.handler = &clientCredentialsGrantHandler{ - JWTService: suite.mockJWTService, + jwtService: suite.mockJWTService, } suite.oauthApp = &appmodel.OAuthAppConfigProcessedDTO{ @@ -77,7 +77,7 @@ func (suite *ClientCredentialsGrantHandlerTestSuite) SetupTest() { } func (suite *ClientCredentialsGrantHandlerTestSuite) TestNewClientCredentialsGrantHandler() { - handler := newClientCredentialsGrantHandler() + handler := newClientCredentialsGrantHandler(suite.mockJWTService) assert.NotNil(suite.T(), handler) assert.Implements(suite.T(), (*GrantHandlerInterface)(nil), handler) } diff --git a/backend/internal/oauth/oauth2/granthandlers/init.go b/backend/internal/oauth/oauth2/granthandlers/init.go new file mode 100644 index 00000000..e4383e42 --- /dev/null +++ b/backend/internal/oauth/oauth2/granthandlers/init.go @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). + * + * WSO2 LLC. licenses this file to you under the Apache License, + * Version 2.0 (the "License"); you may not use this file except + * in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package granthandlers + +import ( + "net/http" + + "github.com/asgardeo/thunder/internal/application" + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" + "github.com/asgardeo/thunder/internal/system/jwt" + "github.com/asgardeo/thunder/internal/user" +) + +// Initialize initializes the grant handler provider with the given services. +func Initialize( + mux *http.ServeMux, + jwtService jwt.JWTServiceInterface, + userService user.UserServiceInterface, + applicationService application.ApplicationServiceInterface, +) GrantHandlerProviderInterface { + authzService := authz.Initialize(mux, applicationService, jwtService) + grantHandlerProvider := newGrantHandlerProvider(jwtService, userService, authzService) + return grantHandlerProvider +} diff --git a/backend/internal/oauth/oauth2/granthandlers/provider.go b/backend/internal/oauth/oauth2/granthandlers/provider.go index 0e9c8c0c..5a921b6e 100644 --- a/backend/internal/oauth/oauth2/granthandlers/provider.go +++ b/backend/internal/oauth/oauth2/granthandlers/provider.go @@ -19,7 +19,10 @@ package granthandlers import ( + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" + "github.com/asgardeo/thunder/internal/system/jwt" + "github.com/asgardeo/thunder/internal/user" ) // GrantHandlerProviderInterface defines the interface for the grant handler provider. @@ -28,22 +31,34 @@ type GrantHandlerProviderInterface interface { } // GrantHandlerProvider implements the GrantHandlerProviderInterface. -type GrantHandlerProvider struct{} +type GrantHandlerProvider struct { + clientCredentialsGrantHandler GrantHandlerInterface + authorizationCodeGrantHandler GrantHandlerInterface + refreshTokenGrantHandler GrantHandlerInterface +} -// NewGrantHandlerProvider creates a new instance of GrantHandlerProvider. -func NewGrantHandlerProvider() GrantHandlerProviderInterface { - return &GrantHandlerProvider{} +// newGrantHandlerProvider creates a new instance of GrantHandlerProvider. +func newGrantHandlerProvider( + jwtService jwt.JWTServiceInterface, + userService user.UserServiceInterface, + authzService authz.AuthorizeServiceInterface, +) GrantHandlerProviderInterface { + return &GrantHandlerProvider{ + clientCredentialsGrantHandler: newClientCredentialsGrantHandler(jwtService), + authorizationCodeGrantHandler: newAuthorizationCodeGrantHandler(jwtService, userService, authzService), + refreshTokenGrantHandler: newRefreshTokenGrantHandler(jwtService, userService), + } } // GetGrantHandler returns the appropriate grant handler for the given grant type. func (p *GrantHandlerProvider) GetGrantHandler(grantType constants.GrantType) (GrantHandlerInterface, error) { switch grantType { case constants.GrantTypeClientCredentials: - return newClientCredentialsGrantHandler(), nil + return p.clientCredentialsGrantHandler, nil case constants.GrantTypeAuthorizationCode: - return newAuthorizationCodeGrantHandler(), nil + return p.authorizationCodeGrantHandler, nil case constants.GrantTypeRefreshToken: - return newRefreshTokenGrantHandler(), nil + return p.refreshTokenGrantHandler, nil default: return nil, constants.UnSupportedGrantTypeError } diff --git a/backend/internal/oauth/oauth2/granthandlers/provider_test.go b/backend/internal/oauth/oauth2/granthandlers/provider_test.go index 43f60d45..754731ab 100644 --- a/backend/internal/oauth/oauth2/granthandlers/provider_test.go +++ b/backend/internal/oauth/oauth2/granthandlers/provider_test.go @@ -25,11 +25,17 @@ import ( "github.com/stretchr/testify/suite" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" + "github.com/asgardeo/thunder/tests/mocks/jwtmock" + "github.com/asgardeo/thunder/tests/mocks/oauth/oauth2/authzmock" + usersvcmock "github.com/asgardeo/thunder/tests/mocks/usermock" ) type GrantHandlerProviderTestSuite struct { suite.Suite - provider GrantHandlerProviderInterface + provider GrantHandlerProviderInterface + mockJWTService *jwtmock.JWTServiceInterfaceMock + mockUserService *usersvcmock.UserServiceInterfaceMock + authzService *authzmock.AuthorizeServiceInterfaceMock } func TestGrantHandlerProviderSuite(t *testing.T) { @@ -37,11 +43,14 @@ func TestGrantHandlerProviderSuite(t *testing.T) { } func (suite *GrantHandlerProviderTestSuite) SetupTest() { - suite.provider = NewGrantHandlerProvider() + suite.mockJWTService = jwtmock.NewJWTServiceInterfaceMock(suite.T()) + suite.mockUserService = usersvcmock.NewUserServiceInterfaceMock(suite.T()) + suite.authzService = authzmock.NewAuthorizeServiceInterfaceMock(suite.T()) + suite.provider = newGrantHandlerProvider(suite.mockJWTService, suite.mockUserService, suite.authzService) } func (suite *GrantHandlerProviderTestSuite) TestNewGrantHandlerProvider() { - provider := NewGrantHandlerProvider() + provider := newGrantHandlerProvider(suite.mockJWTService, suite.mockUserService, suite.authzService) assert.NotNil(suite.T(), provider) assert.Implements(suite.T(), (*GrantHandlerProviderInterface)(nil), provider) } diff --git a/backend/internal/oauth/oauth2/granthandlers/refreshtoken.go b/backend/internal/oauth/oauth2/granthandlers/refreshtoken.go index 2ee8b4f5..aeef9e63 100644 --- a/backend/internal/oauth/oauth2/granthandlers/refreshtoken.go +++ b/backend/internal/oauth/oauth2/granthandlers/refreshtoken.go @@ -34,15 +34,18 @@ import ( // refreshTokenGrantHandler handles the refresh token grant type. type refreshTokenGrantHandler struct { - JWTService jwt.JWTServiceInterface - UserService user.UserServiceInterface + jwtService jwt.JWTServiceInterface + userService user.UserServiceInterface } // newRefreshTokenGrantHandler creates a new instance of RefreshTokenGrantHandler. -func newRefreshTokenGrantHandler() RefreshTokenGrantHandlerInterface { +func newRefreshTokenGrantHandler( + jwtService jwt.JWTServiceInterface, + userService user.UserServiceInterface, +) RefreshTokenGrantHandlerInterface { return &refreshTokenGrantHandler{ - JWTService: jwt.GetJWTService(), - UserService: user.GetUserService(), + jwtService: jwtService, + userService: userService, } } @@ -142,7 +145,7 @@ func (h *refreshTokenGrantHandler) HandleGrant(tokenRequest *model.TokenRequest, } } - accessToken, iat, err := h.JWTService.GenerateJWT(sub, aud, iss, validityPeriod, jwtClaims) + accessToken, iat, err := h.jwtService.GenerateJWT(sub, aud, iss, validityPeriod, jwtClaims) if err != nil { return nil, &model.ErrorResponse{ Error: constants.ErrorServerError, @@ -248,7 +251,7 @@ func (h *refreshTokenGrantHandler) IssueRefreshToken(tokenResponse *model.TokenR claims["access_token_user_attributes"] = tokenResponse.AccessToken.UserAttributes } - token, iat, err := h.JWTService.GenerateJWT(oauthApp.ClientID, oauthApp.ClientID, iss, validityPeriod, claims) + token, iat, err := h.jwtService.GenerateJWT(oauthApp.ClientID, oauthApp.ClientID, iss, validityPeriod, claims) if err != nil { return &model.ErrorResponse{ Error: constants.ErrorServerError, @@ -272,7 +275,7 @@ func (h *refreshTokenGrantHandler) IssueRefreshToken(tokenResponse *model.TokenR // verifyRefreshToken verifies the refresh token using the server's public key. func (h *refreshTokenGrantHandler) verifyRefreshToken(refreshToken string, logger *log.Logger) *model.ErrorResponse { - if err := h.JWTService.VerifyJWT(refreshToken, "", ""); err != nil { + if err := h.jwtService.VerifyJWT(refreshToken, "", ""); err != nil { logger.Error("Failed to verify refresh token signature", log.Error(err)) return &model.ErrorResponse{ Error: constants.ErrorInvalidRequest, diff --git a/backend/internal/oauth/oauth2/granthandlers/refreshtoken_test.go b/backend/internal/oauth/oauth2/granthandlers/refreshtoken_test.go index e7258514..19f39054 100644 --- a/backend/internal/oauth/oauth2/granthandlers/refreshtoken_test.go +++ b/backend/internal/oauth/oauth2/granthandlers/refreshtoken_test.go @@ -73,8 +73,8 @@ func (suite *RefreshTokenGrantHandlerTestSuite) SetupTest() { suite.mockUserService = usersvcmock.NewUserServiceInterfaceMock(suite.T()) suite.handler = &refreshTokenGrantHandler{ - JWTService: suite.mockJWTService, - UserService: suite.mockUserService, + jwtService: suite.mockJWTService, + userService: suite.mockUserService, } suite.oauthApp = &appmodel.OAuthAppConfigProcessedDTO{ @@ -115,7 +115,7 @@ func (suite *RefreshTokenGrantHandlerTestSuite) TearDownTest() { } func (suite *RefreshTokenGrantHandlerTestSuite) TestNewRefreshTokenGrantHandler() { - handler := newRefreshTokenGrantHandler() + handler := newRefreshTokenGrantHandler(suite.mockJWTService, suite.mockUserService) assert.NotNil(suite.T(), handler) assert.Implements(suite.T(), (*RefreshTokenGrantHandlerInterface)(nil), handler) } diff --git a/backend/internal/oauth/oauth2/introspect/TokenIntrospectionServiceInterface_mock_test.go b/backend/internal/oauth/oauth2/introspect/TokenIntrospectionServiceInterface_mock_test.go new file mode 100644 index 00000000..3ba0b2e6 --- /dev/null +++ b/backend/internal/oauth/oauth2/introspect/TokenIntrospectionServiceInterface_mock_test.go @@ -0,0 +1,104 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package introspect + +import ( + mock "github.com/stretchr/testify/mock" +) + +// NewTokenIntrospectionServiceInterfaceMock creates a new instance of TokenIntrospectionServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewTokenIntrospectionServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *TokenIntrospectionServiceInterfaceMock { + mock := &TokenIntrospectionServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// TokenIntrospectionServiceInterfaceMock is an autogenerated mock type for the TokenIntrospectionServiceInterface type +type TokenIntrospectionServiceInterfaceMock struct { + mock.Mock +} + +type TokenIntrospectionServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *TokenIntrospectionServiceInterfaceMock) EXPECT() *TokenIntrospectionServiceInterfaceMock_Expecter { + return &TokenIntrospectionServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// IntrospectToken provides a mock function for the type TokenIntrospectionServiceInterfaceMock +func (_mock *TokenIntrospectionServiceInterfaceMock) IntrospectToken(token string, tokenTypeHint string) (*IntrospectResponse, error) { + ret := _mock.Called(token, tokenTypeHint) + + if len(ret) == 0 { + panic("no return value specified for IntrospectToken") + } + + var r0 *IntrospectResponse + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string) (*IntrospectResponse, error)); ok { + return returnFunc(token, tokenTypeHint) + } + if returnFunc, ok := ret.Get(0).(func(string, string) *IntrospectResponse); ok { + r0 = returnFunc(token, tokenTypeHint) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*IntrospectResponse) + } + } + if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { + r1 = returnFunc(token, tokenTypeHint) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'IntrospectToken' +type TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call struct { + *mock.Call +} + +// IntrospectToken is a helper method to define mock.On call +// - token string +// - tokenTypeHint string +func (_e *TokenIntrospectionServiceInterfaceMock_Expecter) IntrospectToken(token interface{}, tokenTypeHint interface{}) *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call { + return &TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call{Call: _e.mock.On("IntrospectToken", token, tokenTypeHint)} +} + +func (_c *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call) Run(run func(token string, tokenTypeHint string)) *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call) Return(introspectResponse *IntrospectResponse, err error) *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call { + _c.Call.Return(introspectResponse, err) + return _c +} + +func (_c *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call) RunAndReturn(run func(token string, tokenTypeHint string) (*IntrospectResponse, error)) *TokenIntrospectionServiceInterfaceMock_IntrospectToken_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/internal/oauth/oauth2/introspect/handler.go b/backend/internal/oauth/oauth2/introspect/handler.go index e98b1e41..b83b38b1 100644 --- a/backend/internal/oauth/oauth2/introspect/handler.go +++ b/backend/internal/oauth/oauth2/introspect/handler.go @@ -28,20 +28,20 @@ import ( "github.com/asgardeo/thunder/internal/system/log" ) -// TokenIntrospectionHandler handles OAuth 2.0 token introspection requests. -type TokenIntrospectionHandler struct { +// tokenIntrospectionHandler handles OAuth 2.0 token introspection requests. +type tokenIntrospectionHandler struct { service TokenIntrospectionServiceInterface } -// NewTokenIntrospectionHandler creates a new token introspection handler. -func NewTokenIntrospectionHandler(introspectionService TokenIntrospectionServiceInterface) *TokenIntrospectionHandler { - return &TokenIntrospectionHandler{ +// newTokenIntrospectionHandler creates a new token introspection handler (internal use). +func newTokenIntrospectionHandler(introspectionService TokenIntrospectionServiceInterface) *tokenIntrospectionHandler { + return &tokenIntrospectionHandler{ service: introspectionService, } } // HandleIntrospect handles token introspection requests -func (h *TokenIntrospectionHandler) HandleIntrospect(w http.ResponseWriter, r *http.Request) { +func (h *tokenIntrospectionHandler) HandleIntrospect(w http.ResponseWriter, r *http.Request) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "TokenIntrospectionHandler")) if err := r.ParseForm(); err != nil { diff --git a/backend/internal/oauth/oauth2/introspect/handler_test.go b/backend/internal/oauth/oauth2/introspect/handler_test.go index 45732e0c..4447d3dc 100644 --- a/backend/internal/oauth/oauth2/introspect/handler_test.go +++ b/backend/internal/oauth/oauth2/introspect/handler_test.go @@ -16,7 +16,7 @@ * under the License. */ -package introspect_test +package introspect import ( "errors" @@ -27,8 +27,6 @@ import ( "testing" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/introspect" - "github.com/asgardeo/thunder/tests/mocks/oauth/oauth2/introspectmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" @@ -36,8 +34,8 @@ import ( type TokenIntrospectionHandlerTestSuite struct { suite.Suite - introspectionServiceMock *introspectmock.TokenIntrospectionServiceInterfaceMock - handler *introspect.TokenIntrospectionHandler + introspectionServiceMock *TokenIntrospectionServiceInterfaceMock + handler *tokenIntrospectionHandler } func TestTokenIntrospectionHandlerTestSuite(t *testing.T) { @@ -45,8 +43,8 @@ func TestTokenIntrospectionHandlerTestSuite(t *testing.T) { } func (s *TokenIntrospectionHandlerTestSuite) SetupTest() { - s.introspectionServiceMock = introspectmock.NewTokenIntrospectionServiceInterfaceMock(s.T()) - s.handler = introspect.NewTokenIntrospectionHandler(s.introspectionServiceMock) + s.introspectionServiceMock = NewTokenIntrospectionServiceInterfaceMock(s.T()) + s.handler = newTokenIntrospectionHandler(s.introspectionServiceMock) } func (s *TokenIntrospectionHandlerTestSuite) TestHandleIntrospect_ParseFormError() { @@ -98,7 +96,7 @@ func (s *TokenIntrospectionHandlerTestSuite) TestHandleIntrospect_Success_Active req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Setup the mock to return a valid active token response - activeResponse := &introspect.IntrospectResponse{ + activeResponse := &IntrospectResponse{ Active: true, Scope: "openid profile", ClientID: "client123", @@ -147,7 +145,7 @@ func (s *TokenIntrospectionHandlerTestSuite) TestHandleIntrospect_Success_Encode req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Setup the mock to return a valid active token response - activeResponse := &introspect.IntrospectResponse{ + activeResponse := &IntrospectResponse{ Active: true, } s.introspectionServiceMock.On("IntrospectToken", "valid-token", "").Return(activeResponse, nil) @@ -166,7 +164,7 @@ func (s *TokenIntrospectionHandlerTestSuite) TestHandleIntrospect_Success_Inacti req.Header.Set("Content-Type", "application/x-www-form-urlencoded") // Setup the mock to return an inactive token response - inactiveResponse := &introspect.IntrospectResponse{ + inactiveResponse := &IntrospectResponse{ Active: false, } s.introspectionServiceMock.On("IntrospectToken", "invalid-token", "").Return(inactiveResponse, nil) diff --git a/backend/internal/system/services/tokenintrospectservice.go b/backend/internal/oauth/oauth2/introspect/init.go similarity index 52% rename from backend/internal/system/services/tokenintrospectservice.go rename to backend/internal/oauth/oauth2/introspect/init.go index 1619c022..8476c6aa 100644 --- a/backend/internal/system/services/tokenintrospectservice.go +++ b/backend/internal/oauth/oauth2/introspect/init.go @@ -16,39 +16,25 @@ * under the License. */ -package services +package introspect import ( "net/http" - "github.com/asgardeo/thunder/internal/oauth/oauth2/introspect" "github.com/asgardeo/thunder/internal/system/jwt" "github.com/asgardeo/thunder/internal/system/middleware" ) -// TODO: Introspection endpoint MUST require authentication and authorization. -// Implement this when the support is added. - -// TokenIntrospectionAPIService defines the API service for handling OAuth 2.0 token introspection requests. -type TokenIntrospectionAPIService struct { - introspectHandler *introspect.TokenIntrospectionHandler -} - -// NewIntrospectionAPIService creates a new instance of IntrospectionAPIService. -func NewIntrospectionAPIService(mux *http.ServeMux) ServiceInterface { - jwtService := jwt.GetJWTService() - introspectionService := introspect.NewTokenIntrospectionService(jwtService) - - instance := &TokenIntrospectionAPIService{ - introspectHandler: introspect.NewTokenIntrospectionHandler(introspectionService), - } - instance.RegisterRoutes(mux) - - return instance +// Initialize initializes the token introspection handler and registers its routes. +func Initialize(mux *http.ServeMux, jwtService jwt.JWTServiceInterface) TokenIntrospectionServiceInterface { + introspectionService := newTokenIntrospectionService(jwtService) + introspectHandler := newTokenIntrospectionHandler(introspectionService) + registerRoutes(mux, introspectHandler) + return introspectionService } -// RegisterRoutes registers the routes for the IntrospectionAPIService. -func (s *TokenIntrospectionAPIService) RegisterRoutes(mux *http.ServeMux) { +// registerRoutes registers the routes for the IntrospectionAPIService. +func registerRoutes(mux *http.ServeMux, introspectHandler *tokenIntrospectionHandler) { opts := middleware.CORSOptions{ AllowedMethods: "POST, OPTIONS", AllowedHeaders: "Content-Type, Authorization", @@ -56,7 +42,7 @@ func (s *TokenIntrospectionAPIService) RegisterRoutes(mux *http.ServeMux) { } mux.HandleFunc(middleware.WithCORS("POST /oauth2/introspect", - s.introspectHandler.HandleIntrospect, opts)) + introspectHandler.HandleIntrospect, opts)) mux.HandleFunc(middleware.WithCORS("OPTIONS /oauth2/introspect", func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) diff --git a/backend/internal/oauth/oauth2/introspect/service.go b/backend/internal/oauth/oauth2/introspect/service.go index 917484ee..382d6af8 100644 --- a/backend/internal/oauth/oauth2/introspect/service.go +++ b/backend/internal/oauth/oauth2/introspect/service.go @@ -32,21 +32,21 @@ type TokenIntrospectionServiceInterface interface { IntrospectToken(token, tokenTypeHint string) (*IntrospectResponse, error) } -// TokenIntrospectionService implements the TokenIntrospectionServiceInterface. -type TokenIntrospectionService struct { +// tokenIntrospectionService implements the TokenIntrospectionServiceInterface. +type tokenIntrospectionService struct { jwtService jwt.JWTServiceInterface } -// NewTokenIntrospectionService creates a new TokenIntrospectionService instance. -func NewTokenIntrospectionService(jwtService jwt.JWTServiceInterface) TokenIntrospectionServiceInterface { - return &TokenIntrospectionService{ +// newTokenIntrospectionService creates a new tokenIntrospectionService instance (internal use). +func newTokenIntrospectionService(jwtService jwt.JWTServiceInterface) TokenIntrospectionServiceInterface { + return &tokenIntrospectionService{ jwtService: jwtService, } } // IntrospectToken validates and introspects the token. It only returns an error if a server error occurs. // All other failures are treated as inactive token as defined in the RFC 7662. -func (s *TokenIntrospectionService) IntrospectToken(token, tokenTypeHint string) (*IntrospectResponse, error) { +func (s *tokenIntrospectionService) IntrospectToken(token, tokenTypeHint string) (*IntrospectResponse, error) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "TokenIntrospectionService")) if token == "" { @@ -74,7 +74,7 @@ func (s *TokenIntrospectionService) IntrospectToken(token, tokenTypeHint string) } // validateToken verifies the signature and validity of the token. -func (s *TokenIntrospectionService) validateToken(logger *log.Logger, token string) bool { +func (s *tokenIntrospectionService) validateToken(logger *log.Logger, token string) bool { if err := s.jwtService.VerifyJWT(token, "", ""); err != nil { logger.Debug("Failed to verify refresh token", log.Error(err)) return false @@ -83,7 +83,7 @@ func (s *TokenIntrospectionService) validateToken(logger *log.Logger, token stri } // prepareValidResponse prepares the response for a valid token introspection. -func (s *TokenIntrospectionService) prepareValidResponse(payload map[string]interface{}) *IntrospectResponse { +func (s *tokenIntrospectionService) prepareValidResponse(payload map[string]interface{}) *IntrospectResponse { response := &IntrospectResponse{ Active: true, // TODO: Revisit if/when adding support for other token types. diff --git a/backend/internal/oauth/oauth2/introspect/service_test.go b/backend/internal/oauth/oauth2/introspect/service_test.go index 9d2f8e81..7ade03c2 100644 --- a/backend/internal/oauth/oauth2/introspect/service_test.go +++ b/backend/internal/oauth/oauth2/introspect/service_test.go @@ -16,7 +16,7 @@ * under the License. */ -package introspect_test +package introspect import ( "crypto" @@ -30,7 +30,6 @@ import ( "time" "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" - "github.com/asgardeo/thunder/internal/oauth/oauth2/introspect" "github.com/asgardeo/thunder/tests/mocks/jwtmock" "github.com/stretchr/testify/assert" @@ -41,7 +40,7 @@ import ( type TokenIntrospectionServiceTestSuite struct { suite.Suite jwtServiceMock *jwtmock.JWTServiceInterfaceMock - introspectService introspect.TokenIntrospectionServiceInterface + introspectService TokenIntrospectionServiceInterface validToken string expiredToken string notBeforeToken string @@ -63,7 +62,7 @@ func (s *TokenIntrospectionServiceTestSuite) SetupTest() { s.T().Fatal("Error generating RSA key:", err) } - s.introspectService = introspect.NewTokenIntrospectionService(s.jwtServiceMock) + s.introspectService = newTokenIntrospectionService(s.jwtServiceMock) s.validToken = s.createValidToken() s.expiredToken = s.createExpiredToken() diff --git a/backend/internal/oauth/oauth2/model/parameter.go b/backend/internal/oauth/oauth2/model/parameter.go index b9fe18ce..356ce0f6 100644 --- a/backend/internal/oauth/oauth2/model/parameter.go +++ b/backend/internal/oauth/oauth2/model/parameter.go @@ -20,7 +20,6 @@ package model // OAuthParameters represents the parameters required for OAuth2 authorization. type OAuthParameters struct { - SessionDataKey string State string ClientID string RedirectURI string diff --git a/backend/internal/system/services/tokenservice.go b/backend/internal/oauth/oauth2/token/init.go similarity index 50% rename from backend/internal/system/services/tokenservice.go rename to backend/internal/oauth/oauth2/token/init.go index 9c483125..74b38235 100644 --- a/backend/internal/system/services/tokenservice.go +++ b/backend/internal/oauth/oauth2/token/init.go @@ -16,36 +16,35 @@ * under the License. */ -package services +package token import ( "net/http" - "github.com/asgardeo/thunder/internal/oauth/oauth2/token" + "github.com/asgardeo/thunder/internal/application" + "github.com/asgardeo/thunder/internal/oauth/oauth2/granthandlers" + "github.com/asgardeo/thunder/internal/oauth/scope" "github.com/asgardeo/thunder/internal/system/middleware" ) -// TokenService defines the service for handling OAuth2 token requests. -type TokenService struct { - tokenHandler token.TokenHandlerInterface +// Initialize initializes the token handler and registers its routes. +func Initialize( + mux *http.ServeMux, + appService application.ApplicationServiceInterface, + grantHandlerProvider granthandlers.GrantHandlerProviderInterface, + scopeValidator scope.ScopeValidatorInterface, +) TokenHandlerInterface { + tokenHandler := newTokenHandler(appService, grantHandlerProvider, scopeValidator) + registerRoutes(mux, tokenHandler) + return tokenHandler } -// NewTokenService creates a new instance of TokenService. -func NewTokenService(mux *http.ServeMux) ServiceInterface { - instance := &TokenService{ - tokenHandler: token.NewTokenHandler(), - } - instance.RegisterRoutes(mux) - - return instance -} - -// RegisterRoutes registers the routes for the TokenService. -func (s *TokenService) RegisterRoutes(mux *http.ServeMux) { +// registerRoutes registers the routes for the TokenService. +func registerRoutes(mux *http.ServeMux, tokenHandler TokenHandlerInterface) { opts := middleware.CORSOptions{ AllowedMethods: "POST", AllowedHeaders: "Content-Type, Authorization", AllowCredentials: true, } - mux.HandleFunc(middleware.WithCORS("POST /oauth2/token", s.tokenHandler.HandleTokenRequest, opts)) + mux.HandleFunc(middleware.WithCORS("POST /oauth2/token", tokenHandler.HandleTokenRequest, opts)) } diff --git a/backend/internal/oauth/oauth2/token/tokenhandler.go b/backend/internal/oauth/oauth2/token/tokenhandler.go index acadddd0..c9232137 100644 --- a/backend/internal/oauth/oauth2/token/tokenhandler.go +++ b/backend/internal/oauth/oauth2/token/tokenhandler.go @@ -32,7 +32,7 @@ import ( "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" "github.com/asgardeo/thunder/internal/oauth/oauth2/granthandlers" "github.com/asgardeo/thunder/internal/oauth/oauth2/model" - scopeprovider "github.com/asgardeo/thunder/internal/oauth/scope/provider" + "github.com/asgardeo/thunder/internal/oauth/scope" "github.com/asgardeo/thunder/internal/system/log" ) @@ -41,25 +41,29 @@ type TokenHandlerInterface interface { HandleTokenRequest(w http.ResponseWriter, r *http.Request) } -// TokenHandler implements the TokenHandlerInterface. -type TokenHandler struct { - GrantHandlerProvider granthandlers.GrantHandlerProviderInterface - ApplicationProvider application.ApplicationProviderInterface - ScopeValidatorProvider scopeprovider.ScopeValidatorProviderInterface +// tokenHandler implements the TokenHandlerInterface. +type tokenHandler struct { + appService application.ApplicationServiceInterface + grantHandlerProvider granthandlers.GrantHandlerProviderInterface + scopeValidator scope.ScopeValidatorInterface } -// NewTokenHandler creates a new instance of TokenHandler. -func NewTokenHandler() TokenHandlerInterface { - return &TokenHandler{ - GrantHandlerProvider: granthandlers.NewGrantHandlerProvider(), - ApplicationProvider: application.NewApplicationProvider(), - ScopeValidatorProvider: scopeprovider.NewScopeValidatorProvider(), +// newTokenHandler creates a new instance of tokenHandler. +func newTokenHandler( + appService application.ApplicationServiceInterface, + grantHandlerProvider granthandlers.GrantHandlerProviderInterface, + scopeValidator scope.ScopeValidatorInterface, +) TokenHandlerInterface { + return &tokenHandler{ + appService: appService, + grantHandlerProvider: grantHandlerProvider, + scopeValidator: scopeValidator, } } // HandleTokenRequest handles the token request for OAuth 2.0. // It validates the client credentials and delegates to the appropriate grant handler. -func (th *TokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Request) { +func (th *tokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Request) { logger := log.GetLogger().With(log.String(log.LoggerKeyComponentName, "TokenHandler")) // Parse the form data from the request body. @@ -83,7 +87,7 @@ func (th *TokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Reques return } - grantHandler, handlerErr := th.GrantHandlerProvider.GetGrantHandler(grantType) + grantHandler, handlerErr := th.grantHandlerProvider.GetGrantHandler(grantType) if handlerErr != nil { if errors.Is(handlerErr, constants.UnSupportedGrantTypeError) { utils.WriteJSONError(w, constants.ErrorUnsupportedGrantType, "Unsupported grant type", @@ -102,8 +106,7 @@ func (th *TokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Reques } // Retrieve the OAuth application based on the client id. - appService := th.ApplicationProvider.GetApplicationService() - oauthApp, err := appService.GetOAuthApplication(clientID) + oauthApp, err := th.appService.GetOAuthApplication(clientID) if err != nil || oauthApp == nil { utils.WriteJSONError(w, constants.ErrorInvalidClient, "Invalid client credentials", http.StatusUnauthorized, nil) @@ -163,8 +166,7 @@ func (th *TokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Reques } // Validate and filter scopes. - scopeValidator := th.ScopeValidatorProvider.GetScopeValidator() - validScopes, scopeError := scopeValidator.ValidateScopes(tokenRequest.Scope, oauthApp.ClientID) + validScopes, scopeError := th.scopeValidator.ValidateScopes(tokenRequest.Scope, oauthApp.ClientID) if scopeError != nil { utils.WriteJSONError(w, scopeError.Error, scopeError.ErrorDescription, http.StatusBadRequest, nil) return @@ -187,7 +189,7 @@ func (th *TokenHandler) HandleTokenRequest(w http.ResponseWriter, r *http.Reques logger.Debug("Issuing refresh token for the token request", log.String("client_id", clientID), log.String("grant_type", grantTypeStr)) - refreshGrantHandler, handlerErr := th.GrantHandlerProvider.GetGrantHandler(constants.GrantTypeRefreshToken) + refreshGrantHandler, handlerErr := th.grantHandlerProvider.GetGrantHandler(constants.GrantTypeRefreshToken) if handlerErr != nil { logger.Error("Failed to get refresh grant handler", log.Error(handlerErr)) utils.WriteJSONError(w, constants.ErrorServerError, diff --git a/backend/internal/oauth/oauth2/token/tokenhandler_test.go b/backend/internal/oauth/oauth2/token/tokenhandler_test.go index b0f9b753..ff1cef0b 100644 --- a/backend/internal/oauth/oauth2/token/tokenhandler_test.go +++ b/backend/internal/oauth/oauth2/token/tokenhandler_test.go @@ -29,11 +29,24 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/suite" + applicationmodel "github.com/asgardeo/thunder/internal/application/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/constants" "github.com/asgardeo/thunder/internal/system/config" + "github.com/asgardeo/thunder/tests/mocks/applicationmock" + "github.com/asgardeo/thunder/tests/mocks/jwtmock" + "github.com/asgardeo/thunder/tests/mocks/oauth/oauth2/granthandlersmock" + "github.com/asgardeo/thunder/tests/mocks/oauth/scopemock" + "github.com/asgardeo/thunder/tests/mocks/usermock" ) type TokenHandlerTestSuite struct { suite.Suite + mockJWTService *jwtmock.JWTServiceInterfaceMock + mockUserService *usermock.UserServiceInterfaceMock + mockAppService *applicationmock.ApplicationServiceInterfaceMock + mockGrantProvider *granthandlersmock.GrantHandlerProviderInterfaceMock + mockScopeValidator *scopemock.ScopeValidatorInterfaceMock + mockGrantHandler *granthandlersmock.GrantHandlerInterfaceMock } func TestTokenHandlerSuite(t *testing.T) { @@ -48,16 +61,27 @@ func (suite *TokenHandlerTestSuite) SetupTest() { }, } _ = config.InitializeThunderRuntime("test", testConfig) + suite.mockJWTService = &jwtmock.JWTServiceInterfaceMock{} + suite.mockUserService = usermock.NewUserServiceInterfaceMock(suite.T()) + suite.mockGrantProvider = granthandlersmock.NewGrantHandlerProviderInterfaceMock(suite.T()) + suite.mockAppService = applicationmock.NewApplicationServiceInterfaceMock(suite.T()) + suite.mockScopeValidator = scopemock.NewScopeValidatorInterfaceMock(suite.T()) + suite.mockGrantHandler = granthandlersmock.NewGrantHandlerInterfaceMock(suite.T()) + + // Setup common mock for GetGrantHandler that can be used across tests + // Using Maybe() allows tests to override this if needed + suite.mockGrantProvider.On("GetGrantHandler", constants.GrantTypeAuthorizationCode). + Return(suite.mockGrantHandler, nil).Maybe() } -func (suite *TokenHandlerTestSuite) TestNewTokenHandler() { - handler := NewTokenHandler() +func (suite *TokenHandlerTestSuite) TestnewTokenHandler() { + handler := newTokenHandler(suite.mockAppService, suite.mockGrantProvider, suite.mockScopeValidator) assert.NotNil(suite.T(), handler) assert.Implements(suite.T(), (*TokenHandlerInterface)(nil), handler) } func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_InvalidFormData() { - handler := NewTokenHandler() + handler := newTokenHandler(suite.mockAppService, suite.mockGrantProvider, suite.mockScopeValidator) req, _ := http.NewRequest("POST", "/token", strings.NewReader("invalid-form-data%")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -75,31 +99,18 @@ func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_InvalidFormData() { } func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_MissingGrantType() { - handler := NewTokenHandler() formData := url.Values{} formData.Set("client_id", "test-client-id") formData.Set("client_secret", "test-secret") - req, _ := http.NewRequest("POST", "/token", strings.NewReader(formData.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - rr := httptest.NewRecorder() - - handler.HandleTokenRequest(rr, req) - - assert.Equal(suite.T(), http.StatusBadRequest, rr.Code) - - var response map[string]interface{} - err := json.Unmarshal(rr.Body.Bytes(), &response) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "invalid_request", response["error"]) - assert.Equal(suite.T(), "Missing grant_type parameter", response["error_description"]) + suite.testTokenRequestError(formData, http.StatusBadRequest, "invalid_request", + "Missing grant_type parameter") } // Helper function to test token request error scenarios func (suite *TokenHandlerTestSuite) testTokenRequestError(formData url.Values, expectedStatusCode int, expectedError, expectedErrorDescription string) { - handler := NewTokenHandler() + handler := newTokenHandler(suite.mockAppService, suite.mockGrantProvider, suite.mockScopeValidator) req, _ := http.NewRequest("POST", "/token", strings.NewReader(formData.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -128,33 +139,29 @@ func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_InvalidGrantType() { } func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_MissingClientID() { - handler := NewTokenHandler() formData := url.Values{} formData.Set("grant_type", "authorization_code") formData.Set("client_secret", "test-secret") - req, _ := http.NewRequest("POST", "/token", strings.NewReader(formData.Encode())) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - - rr := httptest.NewRecorder() - - handler.HandleTokenRequest(rr, req) - - assert.Equal(suite.T(), http.StatusUnauthorized, rr.Code) - - var response map[string]interface{} - err := json.Unmarshal(rr.Body.Bytes(), &response) - assert.NoError(suite.T(), err) - assert.Equal(suite.T(), "invalid_client", response["error"]) - assert.Equal(suite.T(), "Missing client_id parameter", response["error_description"]) + suite.testTokenRequestError(formData, http.StatusUnauthorized, "invalid_client", + "Missing client_id parameter") } func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_MissingClientSecret() { - handler := NewTokenHandler() + handler := newTokenHandler(suite.mockAppService, suite.mockGrantProvider, suite.mockScopeValidator) formData := url.Values{} formData.Set("grant_type", "authorization_code") formData.Set("client_id", "test-client-id") + // Mock GetOAuthApplication to return a valid app that requires client_secret_post + mockApp := &applicationmodel.OAuthAppConfigProcessedDTO{ + ClientID: "test-client-id", + HashedClientSecret: "hashed-secret", + TokenEndpointAuthMethod: []constants.TokenEndpointAuthMethod{constants.TokenEndpointAuthMethodClientSecretPost}, + GrantTypes: []constants.GrantType{constants.GrantTypeAuthorizationCode}, + } + suite.mockAppService.On("GetOAuthApplication", "test-client-id").Return(mockApp, nil).Once() + req, _ := http.NewRequest("POST", "/token", strings.NewReader(formData.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -168,10 +175,14 @@ func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_MissingClientSecret() err := json.Unmarshal(rr.Body.Bytes(), &response) assert.NoError(suite.T(), err) assert.Equal(suite.T(), "invalid_client", response["error"]) - assert.Equal(suite.T(), "Invalid client credentials", response["error_description"]) + // The error message should mention client_secret or be about missing credentials + assert.Contains(suite.T(), response["error_description"], "client_secret") } func (suite *TokenHandlerTestSuite) TestHandleTokenRequest_InvalidClient() { + // Mock GetOAuthApplication to return nil for invalid client + suite.mockAppService.On("GetOAuthApplication", "invalid-client").Return(nil, nil).Once() + formData := url.Values{} formData.Set("grant_type", "authorization_code") formData.Set("client_id", "invalid-client") diff --git a/backend/internal/oauth/session/utils/sessionutils.go b/backend/internal/oauth/scope/init.go similarity index 69% rename from backend/internal/oauth/session/utils/sessionutils.go rename to backend/internal/oauth/scope/init.go index f0beeb66..e42905cb 100644 --- a/backend/internal/oauth/session/utils/sessionutils.go +++ b/backend/internal/oauth/scope/init.go @@ -16,12 +16,9 @@ * under the License. */ -// Package utils provides utility functions for session management. -package utils +package scope -import "github.com/asgardeo/thunder/internal/system/utils" - -// GenerateNewSessionDataKey generates and returns a new session data key. -func GenerateNewSessionDataKey() string { - return utils.GenerateUUID() +// Initialize initializes and returns a new scope validator. +func Initialize() ScopeValidatorInterface { + return newAPIScopeValidator() } diff --git a/backend/internal/oauth/scope/provider/validatorprovider.go b/backend/internal/oauth/scope/provider/validatorprovider.go deleted file mode 100644 index a7630f51..00000000 --- a/backend/internal/oauth/scope/provider/validatorprovider.go +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). - * - * WSO2 LLC. licenses this file to you under the Apache License, - * Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Package provider provides functionality for managing scope validator instances. -package provider - -import "github.com/asgardeo/thunder/internal/oauth/scope/validator" - -// ScopeValidatorProviderInterface defines the interface for providing a scope validator. -type ScopeValidatorProviderInterface interface { - GetScopeValidator() validator.ScopeValidatorInterface -} - -// ScopeValidatorProvider is the default implementation of the ScopeValidatorProviderInterface. -type ScopeValidatorProvider struct{} - -// NewScopeValidatorProvider creates a new instance of ScopeValidatorProvider. -func NewScopeValidatorProvider() ScopeValidatorProviderInterface { - return &ScopeValidatorProvider{} -} - -// GetScopeValidator returns the scope validator instance. -func (svp *ScopeValidatorProvider) GetScopeValidator() validator.ScopeValidatorInterface { - return validator.NewAPIScopeValidator() -} diff --git a/backend/internal/oauth/scope/provider/validatorprovider_test.go b/backend/internal/oauth/scope/provider/validatorprovider_test.go deleted file mode 100644 index 91a0f1f4..00000000 --- a/backend/internal/oauth/scope/provider/validatorprovider_test.go +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). - * - * WSO2 LLC. licenses this file to you under the Apache License, - * Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package provider - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" - - "github.com/asgardeo/thunder/internal/oauth/scope/validator" -) - -type ScopeValidatorProviderTestSuite struct { - suite.Suite - provider ScopeValidatorProviderInterface -} - -func TestScopeValidatorProviderSuite(t *testing.T) { - suite.Run(t, new(ScopeValidatorProviderTestSuite)) -} - -func (suite *ScopeValidatorProviderTestSuite) SetupTest() { - suite.provider = NewScopeValidatorProvider() -} - -func (suite *ScopeValidatorProviderTestSuite) TestNewScopeValidatorProvider() { - provider := NewScopeValidatorProvider() - assert.NotNil(suite.T(), provider) - assert.IsType(suite.T(), &ScopeValidatorProvider{}, provider) -} - -func (suite *ScopeValidatorProviderTestSuite) TestGetScopeValidator() { - scopeValidator := suite.provider.GetScopeValidator() - assert.NotNil(suite.T(), scopeValidator) - assert.Implements(suite.T(), (*validator.ScopeValidatorInterface)(nil), scopeValidator) -} - -func (suite *ScopeValidatorProviderTestSuite) TestGetScopeValidatorReturnsConsistentInstance() { - validator1 := suite.provider.GetScopeValidator() - validator2 := suite.provider.GetScopeValidator() - - assert.NotNil(suite.T(), validator1) - assert.NotNil(suite.T(), validator2) - assert.IsType(suite.T(), validator1, validator2) -} - -func (suite *ScopeValidatorProviderTestSuite) TestScopeValidatorProviderInterface() { - var _ ScopeValidatorProviderInterface = &ScopeValidatorProvider{} - - var provider ScopeValidatorProviderInterface = NewScopeValidatorProvider() - scopeValidator := provider.GetScopeValidator() - assert.NotNil(suite.T(), scopeValidator) - assert.Implements(suite.T(), (*validator.ScopeValidatorInterface)(nil), scopeValidator) -} - -func (suite *ScopeValidatorProviderTestSuite) TestGetScopeValidatorFunctionality() { - scopeValidator := suite.provider.GetScopeValidator() - - result, err := scopeValidator.ValidateScopes("read write", "test-client") - assert.Equal(suite.T(), "read write", result) - assert.Nil(suite.T(), err) -} diff --git a/backend/internal/oauth/scope/validator/scopevalidator.go b/backend/internal/oauth/scope/validator.go similarity index 76% rename from backend/internal/oauth/scope/validator/scopevalidator.go rename to backend/internal/oauth/scope/validator.go index b836195a..ba462041 100644 --- a/backend/internal/oauth/scope/validator/scopevalidator.go +++ b/backend/internal/oauth/scope/validator.go @@ -16,8 +16,8 @@ * under the License. */ -// Package validator provides functionality for validating scopes. -package validator +// Package scope provides functionality for validating scopes. +package scope // ScopeError represents an error during scope validation. type ScopeError struct { @@ -30,16 +30,16 @@ type ScopeValidatorInterface interface { ValidateScopes(requestedScopes, clientID string) (string, *ScopeError) } -// APIScopeValidator is the implementation of API scope validation. -type APIScopeValidator struct{} +// apiScopeValidator is the implementation of API scope validation. +type apiScopeValidator struct{} -// NewAPIScopeValidator creates a new instance of the APIScopeValidator. -func NewAPIScopeValidator() *APIScopeValidator { - return &APIScopeValidator{} +// newAPIScopeValidator creates a new instance of the apiScopeValidator. +func newAPIScopeValidator() *apiScopeValidator { + return &apiScopeValidator{} } // ValidateScopes validates and filters the requested scopes against the authorized scopes for the application. -func (sv *APIScopeValidator) ValidateScopes(requestedScopes, clientID string) (string, *ScopeError) { +func (sv *apiScopeValidator) ValidateScopes(requestedScopes, clientID string) (string, *ScopeError) { if requestedScopes == "" { return "", nil } diff --git a/backend/internal/oauth/scope/validator/scopevalidator_test.go b/backend/internal/oauth/scope/validator_test.go similarity index 91% rename from backend/internal/oauth/scope/validator/scopevalidator_test.go rename to backend/internal/oauth/scope/validator_test.go index ee8b42d4..3af62d38 100644 --- a/backend/internal/oauth/scope/validator/scopevalidator_test.go +++ b/backend/internal/oauth/scope/validator_test.go @@ -16,7 +16,7 @@ * under the License. */ -package validator +package scope import ( "testing" @@ -35,13 +35,13 @@ func TestScopeValidatorSuite(t *testing.T) { } func (suite *ScopeValidatorTestSuite) SetupTest() { - suite.validator = NewAPIScopeValidator() + suite.validator = newAPIScopeValidator() } func (suite *ScopeValidatorTestSuite) TestNewAPIScopeValidator() { - validator := NewAPIScopeValidator() + validator := newAPIScopeValidator() assert.NotNil(suite.T(), validator) - assert.IsType(suite.T(), &APIScopeValidator{}, validator) + assert.IsType(suite.T(), &apiScopeValidator{}, validator) } func (suite *ScopeValidatorTestSuite) TestValidateScopes() { @@ -100,9 +100,9 @@ func (suite *ScopeValidatorTestSuite) TestValidateScopes() { } func (suite *ScopeValidatorTestSuite) TestValidateScopesInterface() { - var _ ScopeValidatorInterface = &APIScopeValidator{} + var _ ScopeValidatorInterface = &apiScopeValidator{} - validator := NewAPIScopeValidator() + validator := newAPIScopeValidator() scopes, err := validator.ValidateScopes("test", "client") assert.Equal(suite.T(), "test", scopes) assert.Nil(suite.T(), err) diff --git a/backend/internal/oauth/session/model/sessiondata.go b/backend/internal/oauth/session/model/sessiondata.go deleted file mode 100644 index d5abaf89..00000000 --- a/backend/internal/oauth/session/model/sessiondata.go +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). - * - * WSO2 LLC. licenses this file to you under the Apache License, - * Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -// Package model defines the data structures for managing auth session data. -package model - -import ( - "time" - - authncm "github.com/asgardeo/thunder/internal/authn/common" - oauthmodel "github.com/asgardeo/thunder/internal/oauth/oauth2/model" -) - -// SessionData represents the session data for the authentication. -type SessionData struct { - OAuthParameters oauthmodel.OAuthParameters - AuthTime time.Time - AuthenticatedUser authncm.AuthenticatedUser -} diff --git a/backend/internal/oauth/session/utils/sessionutils_test.go b/backend/internal/oauth/session/utils/sessionutils_test.go deleted file mode 100644 index 49fef89b..00000000 --- a/backend/internal/oauth/session/utils/sessionutils_test.go +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Copyright (c) 2025, WSO2 LLC. (https://www.wso2.com). - * - * WSO2 LLC. licenses this file to you under the Apache License, - * Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package utils - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" -) - -type SessionUtilsTestSuite struct { - suite.Suite -} - -func TestSessionUtilsSuite(t *testing.T) { - suite.Run(t, new(SessionUtilsTestSuite)) -} - -func (suite *SessionUtilsTestSuite) TestGenerateNewSessionDataKey() { - sessionKey := GenerateNewSessionDataKey() - assert.NotEmpty(suite.T(), sessionKey) - assert.Len(suite.T(), sessionKey, 36) - assert.Contains(suite.T(), sessionKey, "-") -} - -func (suite *SessionUtilsTestSuite) TestGenerateNewSessionDataKeyUniqueness() { - keys := make(map[string]bool) - numKeys := 100 - - for i := 0; i < numKeys; i++ { - key := GenerateNewSessionDataKey() - assert.False(suite.T(), keys[key], "Duplicate session key generated: %s", key) - keys[key] = true - assert.NotEmpty(suite.T(), key) - } - - assert.Len(suite.T(), keys, numKeys) -} - -func (suite *SessionUtilsTestSuite) TestGenerateNewSessionDataKeyFormat() { - sessionKey := GenerateNewSessionDataKey() - - assert.Len(suite.T(), sessionKey, 36) - assert.Equal(suite.T(), "-", string(sessionKey[8])) - assert.Equal(suite.T(), "-", string(sessionKey[13])) - assert.Equal(suite.T(), "-", string(sessionKey[18])) - assert.Equal(suite.T(), "-", string(sessionKey[23])) -} - -func (suite *SessionUtilsTestSuite) TestGenerateNewSessionDataKeyConsistentLength() { - for i := 0; i < 10; i++ { - key := GenerateNewSessionDataKey() - assert.Len(suite.T(), key, 36, "Session key should always be 36 characters") - } -} diff --git a/backend/tests/mocks/applicationmock/ApplicationServiceInterface_mock.go b/backend/tests/mocks/applicationmock/ApplicationServiceInterface_mock.go new file mode 100644 index 00000000..c82643b8 --- /dev/null +++ b/backend/tests/mocks/applicationmock/ApplicationServiceInterface_mock.go @@ -0,0 +1,410 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package applicationmock + +import ( + "github.com/asgardeo/thunder/internal/application/model" + "github.com/asgardeo/thunder/internal/system/error/serviceerror" + mock "github.com/stretchr/testify/mock" +) + +// NewApplicationServiceInterfaceMock creates a new instance of ApplicationServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewApplicationServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *ApplicationServiceInterfaceMock { + mock := &ApplicationServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// ApplicationServiceInterfaceMock is an autogenerated mock type for the ApplicationServiceInterface type +type ApplicationServiceInterfaceMock struct { + mock.Mock +} + +type ApplicationServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *ApplicationServiceInterfaceMock) EXPECT() *ApplicationServiceInterfaceMock_Expecter { + return &ApplicationServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// CreateApplication provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) CreateApplication(app *model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError) { + ret := _mock.Called(app) + + if len(ret) == 0 { + panic("no return value specified for CreateApplication") + } + + var r0 *model.ApplicationDTO + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(*model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError)); ok { + return returnFunc(app) + } + if returnFunc, ok := ret.Get(0).(func(*model.ApplicationDTO) *model.ApplicationDTO); ok { + r0 = returnFunc(app) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ApplicationDTO) + } + } + if returnFunc, ok := ret.Get(1).(func(*model.ApplicationDTO) *serviceerror.ServiceError); ok { + r1 = returnFunc(app) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// ApplicationServiceInterfaceMock_CreateApplication_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CreateApplication' +type ApplicationServiceInterfaceMock_CreateApplication_Call struct { + *mock.Call +} + +// CreateApplication is a helper method to define mock.On call +// - app *model.ApplicationDTO +func (_e *ApplicationServiceInterfaceMock_Expecter) CreateApplication(app interface{}) *ApplicationServiceInterfaceMock_CreateApplication_Call { + return &ApplicationServiceInterfaceMock_CreateApplication_Call{Call: _e.mock.On("CreateApplication", app)} +} + +func (_c *ApplicationServiceInterfaceMock_CreateApplication_Call) Run(run func(app *model.ApplicationDTO)) *ApplicationServiceInterfaceMock_CreateApplication_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 *model.ApplicationDTO + if args[0] != nil { + arg0 = args[0].(*model.ApplicationDTO) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_CreateApplication_Call) Return(applicationDTO *model.ApplicationDTO, serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_CreateApplication_Call { + _c.Call.Return(applicationDTO, serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_CreateApplication_Call) RunAndReturn(run func(app *model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError)) *ApplicationServiceInterfaceMock_CreateApplication_Call { + _c.Call.Return(run) + return _c +} + +// DeleteApplication provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) DeleteApplication(appID string) *serviceerror.ServiceError { + ret := _mock.Called(appID) + + if len(ret) == 0 { + panic("no return value specified for DeleteApplication") + } + + var r0 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) *serviceerror.ServiceError); ok { + r0 = returnFunc(appID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*serviceerror.ServiceError) + } + } + return r0 +} + +// ApplicationServiceInterfaceMock_DeleteApplication_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'DeleteApplication' +type ApplicationServiceInterfaceMock_DeleteApplication_Call struct { + *mock.Call +} + +// DeleteApplication is a helper method to define mock.On call +// - appID string +func (_e *ApplicationServiceInterfaceMock_Expecter) DeleteApplication(appID interface{}) *ApplicationServiceInterfaceMock_DeleteApplication_Call { + return &ApplicationServiceInterfaceMock_DeleteApplication_Call{Call: _e.mock.On("DeleteApplication", appID)} +} + +func (_c *ApplicationServiceInterfaceMock_DeleteApplication_Call) Run(run func(appID string)) *ApplicationServiceInterfaceMock_DeleteApplication_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_DeleteApplication_Call) Return(serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_DeleteApplication_Call { + _c.Call.Return(serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_DeleteApplication_Call) RunAndReturn(run func(appID string) *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_DeleteApplication_Call { + _c.Call.Return(run) + return _c +} + +// GetApplication provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) GetApplication(appID string) (*model.ApplicationProcessedDTO, *serviceerror.ServiceError) { + ret := _mock.Called(appID) + + if len(ret) == 0 { + panic("no return value specified for GetApplication") + } + + var r0 *model.ApplicationProcessedDTO + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) (*model.ApplicationProcessedDTO, *serviceerror.ServiceError)); ok { + return returnFunc(appID) + } + if returnFunc, ok := ret.Get(0).(func(string) *model.ApplicationProcessedDTO); ok { + r0 = returnFunc(appID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ApplicationProcessedDTO) + } + } + if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { + r1 = returnFunc(appID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// ApplicationServiceInterfaceMock_GetApplication_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetApplication' +type ApplicationServiceInterfaceMock_GetApplication_Call struct { + *mock.Call +} + +// GetApplication is a helper method to define mock.On call +// - appID string +func (_e *ApplicationServiceInterfaceMock_Expecter) GetApplication(appID interface{}) *ApplicationServiceInterfaceMock_GetApplication_Call { + return &ApplicationServiceInterfaceMock_GetApplication_Call{Call: _e.mock.On("GetApplication", appID)} +} + +func (_c *ApplicationServiceInterfaceMock_GetApplication_Call) Run(run func(appID string)) *ApplicationServiceInterfaceMock_GetApplication_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetApplication_Call) Return(applicationProcessedDTO *model.ApplicationProcessedDTO, serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_GetApplication_Call { + _c.Call.Return(applicationProcessedDTO, serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetApplication_Call) RunAndReturn(run func(appID string) (*model.ApplicationProcessedDTO, *serviceerror.ServiceError)) *ApplicationServiceInterfaceMock_GetApplication_Call { + _c.Call.Return(run) + return _c +} + +// GetApplicationList provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) GetApplicationList() (*model.ApplicationListResponse, *serviceerror.ServiceError) { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for GetApplicationList") + } + + var r0 *model.ApplicationListResponse + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func() (*model.ApplicationListResponse, *serviceerror.ServiceError)); ok { + return returnFunc() + } + if returnFunc, ok := ret.Get(0).(func() *model.ApplicationListResponse); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ApplicationListResponse) + } + } + if returnFunc, ok := ret.Get(1).(func() *serviceerror.ServiceError); ok { + r1 = returnFunc() + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// ApplicationServiceInterfaceMock_GetApplicationList_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetApplicationList' +type ApplicationServiceInterfaceMock_GetApplicationList_Call struct { + *mock.Call +} + +// GetApplicationList is a helper method to define mock.On call +func (_e *ApplicationServiceInterfaceMock_Expecter) GetApplicationList() *ApplicationServiceInterfaceMock_GetApplicationList_Call { + return &ApplicationServiceInterfaceMock_GetApplicationList_Call{Call: _e.mock.On("GetApplicationList")} +} + +func (_c *ApplicationServiceInterfaceMock_GetApplicationList_Call) Run(run func()) *ApplicationServiceInterfaceMock_GetApplicationList_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetApplicationList_Call) Return(applicationListResponse *model.ApplicationListResponse, serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_GetApplicationList_Call { + _c.Call.Return(applicationListResponse, serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetApplicationList_Call) RunAndReturn(run func() (*model.ApplicationListResponse, *serviceerror.ServiceError)) *ApplicationServiceInterfaceMock_GetApplicationList_Call { + _c.Call.Return(run) + return _c +} + +// GetOAuthApplication provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) GetOAuthApplication(clientID string) (*model.OAuthAppConfigProcessedDTO, *serviceerror.ServiceError) { + ret := _mock.Called(clientID) + + if len(ret) == 0 { + panic("no return value specified for GetOAuthApplication") + } + + var r0 *model.OAuthAppConfigProcessedDTO + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string) (*model.OAuthAppConfigProcessedDTO, *serviceerror.ServiceError)); ok { + return returnFunc(clientID) + } + if returnFunc, ok := ret.Get(0).(func(string) *model.OAuthAppConfigProcessedDTO); ok { + r0 = returnFunc(clientID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.OAuthAppConfigProcessedDTO) + } + } + if returnFunc, ok := ret.Get(1).(func(string) *serviceerror.ServiceError); ok { + r1 = returnFunc(clientID) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// ApplicationServiceInterfaceMock_GetOAuthApplication_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetOAuthApplication' +type ApplicationServiceInterfaceMock_GetOAuthApplication_Call struct { + *mock.Call +} + +// GetOAuthApplication is a helper method to define mock.On call +// - clientID string +func (_e *ApplicationServiceInterfaceMock_Expecter) GetOAuthApplication(clientID interface{}) *ApplicationServiceInterfaceMock_GetOAuthApplication_Call { + return &ApplicationServiceInterfaceMock_GetOAuthApplication_Call{Call: _e.mock.On("GetOAuthApplication", clientID)} +} + +func (_c *ApplicationServiceInterfaceMock_GetOAuthApplication_Call) Run(run func(clientID string)) *ApplicationServiceInterfaceMock_GetOAuthApplication_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetOAuthApplication_Call) Return(oAuthAppConfigProcessedDTO *model.OAuthAppConfigProcessedDTO, serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_GetOAuthApplication_Call { + _c.Call.Return(oAuthAppConfigProcessedDTO, serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_GetOAuthApplication_Call) RunAndReturn(run func(clientID string) (*model.OAuthAppConfigProcessedDTO, *serviceerror.ServiceError)) *ApplicationServiceInterfaceMock_GetOAuthApplication_Call { + _c.Call.Return(run) + return _c +} + +// UpdateApplication provides a mock function for the type ApplicationServiceInterfaceMock +func (_mock *ApplicationServiceInterfaceMock) UpdateApplication(appID string, app *model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError) { + ret := _mock.Called(appID, app) + + if len(ret) == 0 { + panic("no return value specified for UpdateApplication") + } + + var r0 *model.ApplicationDTO + var r1 *serviceerror.ServiceError + if returnFunc, ok := ret.Get(0).(func(string, *model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError)); ok { + return returnFunc(appID, app) + } + if returnFunc, ok := ret.Get(0).(func(string, *model.ApplicationDTO) *model.ApplicationDTO); ok { + r0 = returnFunc(appID, app) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*model.ApplicationDTO) + } + } + if returnFunc, ok := ret.Get(1).(func(string, *model.ApplicationDTO) *serviceerror.ServiceError); ok { + r1 = returnFunc(appID, app) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*serviceerror.ServiceError) + } + } + return r0, r1 +} + +// ApplicationServiceInterfaceMock_UpdateApplication_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateApplication' +type ApplicationServiceInterfaceMock_UpdateApplication_Call struct { + *mock.Call +} + +// UpdateApplication is a helper method to define mock.On call +// - appID string +// - app *model.ApplicationDTO +func (_e *ApplicationServiceInterfaceMock_Expecter) UpdateApplication(appID interface{}, app interface{}) *ApplicationServiceInterfaceMock_UpdateApplication_Call { + return &ApplicationServiceInterfaceMock_UpdateApplication_Call{Call: _e.mock.On("UpdateApplication", appID, app)} +} + +func (_c *ApplicationServiceInterfaceMock_UpdateApplication_Call) Run(run func(appID string, app *model.ApplicationDTO)) *ApplicationServiceInterfaceMock_UpdateApplication_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 *model.ApplicationDTO + if args[1] != nil { + arg1 = args[1].(*model.ApplicationDTO) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_UpdateApplication_Call) Return(applicationDTO *model.ApplicationDTO, serviceError *serviceerror.ServiceError) *ApplicationServiceInterfaceMock_UpdateApplication_Call { + _c.Call.Return(applicationDTO, serviceError) + return _c +} + +func (_c *ApplicationServiceInterfaceMock_UpdateApplication_Call) RunAndReturn(run func(appID string, app *model.ApplicationDTO) (*model.ApplicationDTO, *serviceerror.ServiceError)) *ApplicationServiceInterfaceMock_UpdateApplication_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/tests/mocks/oauth/jwksmock/JWKSServiceInterface_mock.go b/backend/tests/mocks/oauth/jwksmock/JWKSServiceInterface_mock.go index 9ac4ca5e..b254422e 100644 --- a/backend/tests/mocks/oauth/jwksmock/JWKSServiceInterface_mock.go +++ b/backend/tests/mocks/oauth/jwksmock/JWKSServiceInterface_mock.go @@ -5,7 +5,7 @@ package jwksmock import ( - "github.com/asgardeo/thunder/internal/oauth/jwks/model" + "github.com/asgardeo/thunder/internal/oauth/jwks" "github.com/asgardeo/thunder/internal/system/error/serviceerror" mock "github.com/stretchr/testify/mock" ) @@ -38,23 +38,23 @@ func (_m *JWKSServiceInterfaceMock) EXPECT() *JWKSServiceInterfaceMock_Expecter } // GetJWKS provides a mock function for the type JWKSServiceInterfaceMock -func (_mock *JWKSServiceInterfaceMock) GetJWKS() (*model.JWKSResponse, *serviceerror.ServiceError) { +func (_mock *JWKSServiceInterfaceMock) GetJWKS() (*jwks.JWKSResponse, *serviceerror.ServiceError) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for GetJWKS") } - var r0 *model.JWKSResponse + var r0 *jwks.JWKSResponse var r1 *serviceerror.ServiceError - if returnFunc, ok := ret.Get(0).(func() (*model.JWKSResponse, *serviceerror.ServiceError)); ok { + if returnFunc, ok := ret.Get(0).(func() (*jwks.JWKSResponse, *serviceerror.ServiceError)); ok { return returnFunc() } - if returnFunc, ok := ret.Get(0).(func() *model.JWKSResponse); ok { + if returnFunc, ok := ret.Get(0).(func() *jwks.JWKSResponse); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*model.JWKSResponse) + r0 = ret.Get(0).(*jwks.JWKSResponse) } } if returnFunc, ok := ret.Get(1).(func() *serviceerror.ServiceError); ok { @@ -84,12 +84,12 @@ func (_c *JWKSServiceInterfaceMock_GetJWKS_Call) Run(run func()) *JWKSServiceInt return _c } -func (_c *JWKSServiceInterfaceMock_GetJWKS_Call) Return(jWKSResponse *model.JWKSResponse, serviceError *serviceerror.ServiceError) *JWKSServiceInterfaceMock_GetJWKS_Call { +func (_c *JWKSServiceInterfaceMock_GetJWKS_Call) Return(jWKSResponse *jwks.JWKSResponse, serviceError *serviceerror.ServiceError) *JWKSServiceInterfaceMock_GetJWKS_Call { _c.Call.Return(jWKSResponse, serviceError) return _c } -func (_c *JWKSServiceInterfaceMock_GetJWKS_Call) RunAndReturn(run func() (*model.JWKSResponse, *serviceerror.ServiceError)) *JWKSServiceInterfaceMock_GetJWKS_Call { +func (_c *JWKSServiceInterfaceMock_GetJWKS_Call) RunAndReturn(run func() (*jwks.JWKSResponse, *serviceerror.ServiceError)) *JWKSServiceInterfaceMock_GetJWKS_Call { _c.Call.Return(run) return _c } diff --git a/backend/tests/mocks/oauth/oauth2/authz/storemock/AuthorizationCodeStoreInterface_mock.go b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationCodeStoreInterface_mock.go similarity index 83% rename from backend/tests/mocks/oauth/oauth2/authz/storemock/AuthorizationCodeStoreInterface_mock.go rename to backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationCodeStoreInterface_mock.go index feb2ff36..183fc6fe 100644 --- a/backend/tests/mocks/oauth/oauth2/authz/storemock/AuthorizationCodeStoreInterface_mock.go +++ b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationCodeStoreInterface_mock.go @@ -2,10 +2,10 @@ // github.com/vektra/mockery // template: testify -package storemock +package authzmock import ( - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" mock "github.com/stretchr/testify/mock" ) @@ -37,7 +37,7 @@ func (_m *AuthorizationCodeStoreInterfaceMock) EXPECT() *AuthorizationCodeStoreI } // DeactivateAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock -func (_mock *AuthorizationCodeStoreInterfaceMock) DeactivateAuthorizationCode(authzCode model.AuthorizationCode) error { +func (_mock *AuthorizationCodeStoreInterfaceMock) DeactivateAuthorizationCode(authzCode authz.AuthorizationCode) error { ret := _mock.Called(authzCode) if len(ret) == 0 { @@ -45,7 +45,7 @@ func (_mock *AuthorizationCodeStoreInterfaceMock) DeactivateAuthorizationCode(au } var r0 error - if returnFunc, ok := ret.Get(0).(func(model.AuthorizationCode) error); ok { + if returnFunc, ok := ret.Get(0).(func(authz.AuthorizationCode) error); ok { r0 = returnFunc(authzCode) } else { r0 = ret.Error(0) @@ -59,16 +59,16 @@ type AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call struct } // DeactivateAuthorizationCode is a helper method to define mock.On call -// - authzCode model.AuthorizationCode +// - authzCode authz.AuthorizationCode func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) DeactivateAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { return &AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call{Call: _e.mock.On("DeactivateAuthorizationCode", authzCode)} } -func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) Run(run func(authzCode model.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) Run(run func(authzCode authz.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 model.AuthorizationCode + var arg0 authz.AuthorizationCode if args[0] != nil { - arg0 = args[0].(model.AuthorizationCode) + arg0 = args[0].(authz.AuthorizationCode) } run( arg0, @@ -82,13 +82,13 @@ func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) RunAndReturn(run func(authzCode model.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call) RunAndReturn(run func(authzCode authz.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_DeactivateAuthorizationCode_Call { _c.Call.Return(run) return _c } // ExpireAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock -func (_mock *AuthorizationCodeStoreInterfaceMock) ExpireAuthorizationCode(authzCode model.AuthorizationCode) error { +func (_mock *AuthorizationCodeStoreInterfaceMock) ExpireAuthorizationCode(authzCode authz.AuthorizationCode) error { ret := _mock.Called(authzCode) if len(ret) == 0 { @@ -96,7 +96,7 @@ func (_mock *AuthorizationCodeStoreInterfaceMock) ExpireAuthorizationCode(authzC } var r0 error - if returnFunc, ok := ret.Get(0).(func(model.AuthorizationCode) error); ok { + if returnFunc, ok := ret.Get(0).(func(authz.AuthorizationCode) error); ok { r0 = returnFunc(authzCode) } else { r0 = ret.Error(0) @@ -110,16 +110,16 @@ type AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call struct { } // ExpireAuthorizationCode is a helper method to define mock.On call -// - authzCode model.AuthorizationCode +// - authzCode authz.AuthorizationCode func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) ExpireAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { return &AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call{Call: _e.mock.On("ExpireAuthorizationCode", authzCode)} } -func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) Run(run func(authzCode model.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) Run(run func(authzCode authz.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 model.AuthorizationCode + var arg0 authz.AuthorizationCode if args[0] != nil { - arg0 = args[0].(model.AuthorizationCode) + arg0 = args[0].(authz.AuthorizationCode) } run( arg0, @@ -133,28 +133,28 @@ func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) Retu return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) RunAndReturn(run func(authzCode model.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call) RunAndReturn(run func(authzCode authz.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_ExpireAuthorizationCode_Call { _c.Call.Return(run) return _c } // GetAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock -func (_mock *AuthorizationCodeStoreInterfaceMock) GetAuthorizationCode(clientID string, authCode string) (model.AuthorizationCode, error) { +func (_mock *AuthorizationCodeStoreInterfaceMock) GetAuthorizationCode(clientID string, authCode string) (authz.AuthorizationCode, error) { ret := _mock.Called(clientID, authCode) if len(ret) == 0 { panic("no return value specified for GetAuthorizationCode") } - var r0 model.AuthorizationCode + var r0 authz.AuthorizationCode var r1 error - if returnFunc, ok := ret.Get(0).(func(string, string) (model.AuthorizationCode, error)); ok { + if returnFunc, ok := ret.Get(0).(func(string, string) (authz.AuthorizationCode, error)); ok { return returnFunc(clientID, authCode) } - if returnFunc, ok := ret.Get(0).(func(string, string) model.AuthorizationCode); ok { + if returnFunc, ok := ret.Get(0).(func(string, string) authz.AuthorizationCode); ok { r0 = returnFunc(clientID, authCode) } else { - r0 = ret.Get(0).(model.AuthorizationCode) + r0 = ret.Get(0).(authz.AuthorizationCode) } if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { r1 = returnFunc(clientID, authCode) @@ -194,18 +194,18 @@ func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) Run(run return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) Return(authorizationCode model.AuthorizationCode, err error) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) Return(authorizationCode authz.AuthorizationCode, err error) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { _c.Call.Return(authorizationCode, err) return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) RunAndReturn(run func(clientID string, authCode string) (model.AuthorizationCode, error)) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call) RunAndReturn(run func(clientID string, authCode string) (authz.AuthorizationCode, error)) *AuthorizationCodeStoreInterfaceMock_GetAuthorizationCode_Call { _c.Call.Return(run) return _c } // InsertAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock -func (_mock *AuthorizationCodeStoreInterfaceMock) InsertAuthorizationCode(authzCode model.AuthorizationCode) error { +func (_mock *AuthorizationCodeStoreInterfaceMock) InsertAuthorizationCode(authzCode authz.AuthorizationCode) error { ret := _mock.Called(authzCode) if len(ret) == 0 { @@ -213,7 +213,7 @@ func (_mock *AuthorizationCodeStoreInterfaceMock) InsertAuthorizationCode(authzC } var r0 error - if returnFunc, ok := ret.Get(0).(func(model.AuthorizationCode) error); ok { + if returnFunc, ok := ret.Get(0).(func(authz.AuthorizationCode) error); ok { r0 = returnFunc(authzCode) } else { r0 = ret.Error(0) @@ -227,16 +227,16 @@ type AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call struct { } // InsertAuthorizationCode is a helper method to define mock.On call -// - authzCode model.AuthorizationCode +// - authzCode authz.AuthorizationCode func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) InsertAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { return &AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call{Call: _e.mock.On("InsertAuthorizationCode", authzCode)} } -func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) Run(run func(authzCode model.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) Run(run func(authzCode authz.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 model.AuthorizationCode + var arg0 authz.AuthorizationCode if args[0] != nil { - arg0 = args[0].(model.AuthorizationCode) + arg0 = args[0].(authz.AuthorizationCode) } run( arg0, @@ -250,13 +250,13 @@ func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) Retu return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) RunAndReturn(run func(authzCode model.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call) RunAndReturn(run func(authzCode authz.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_InsertAuthorizationCode_Call { _c.Call.Return(run) return _c } // RevokeAuthorizationCode provides a mock function for the type AuthorizationCodeStoreInterfaceMock -func (_mock *AuthorizationCodeStoreInterfaceMock) RevokeAuthorizationCode(authzCode model.AuthorizationCode) error { +func (_mock *AuthorizationCodeStoreInterfaceMock) RevokeAuthorizationCode(authzCode authz.AuthorizationCode) error { ret := _mock.Called(authzCode) if len(ret) == 0 { @@ -264,7 +264,7 @@ func (_mock *AuthorizationCodeStoreInterfaceMock) RevokeAuthorizationCode(authzC } var r0 error - if returnFunc, ok := ret.Get(0).(func(model.AuthorizationCode) error); ok { + if returnFunc, ok := ret.Get(0).(func(authz.AuthorizationCode) error); ok { r0 = returnFunc(authzCode) } else { r0 = ret.Error(0) @@ -278,16 +278,16 @@ type AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call struct { } // RevokeAuthorizationCode is a helper method to define mock.On call -// - authzCode model.AuthorizationCode +// - authzCode authz.AuthorizationCode func (_e *AuthorizationCodeStoreInterfaceMock_Expecter) RevokeAuthorizationCode(authzCode interface{}) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { return &AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call{Call: _e.mock.On("RevokeAuthorizationCode", authzCode)} } -func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) Run(run func(authzCode model.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) Run(run func(authzCode authz.AuthorizationCode)) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 model.AuthorizationCode + var arg0 authz.AuthorizationCode if args[0] != nil { - arg0 = args[0].(model.AuthorizationCode) + arg0 = args[0].(authz.AuthorizationCode) } run( arg0, @@ -301,7 +301,7 @@ func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) Retu return _c } -func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) RunAndReturn(run func(authzCode model.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { +func (_c *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call) RunAndReturn(run func(authzCode authz.AuthorizationCode) error) *AuthorizationCodeStoreInterfaceMock_RevokeAuthorizationCode_Call { _c.Call.Return(run) return _c } diff --git a/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationValidatorInterface_mock.go b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationValidatorInterface_mock.go index 389121be..0bf765c0 100644 --- a/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationValidatorInterface_mock.go +++ b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizationValidatorInterface_mock.go @@ -5,8 +5,8 @@ package authzmock import ( - model0 "github.com/asgardeo/thunder/internal/application/model" - "github.com/asgardeo/thunder/internal/oauth/oauth2/authz/model" + "github.com/asgardeo/thunder/internal/application/model" + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" mock "github.com/stretchr/testify/mock" ) @@ -38,7 +38,7 @@ func (_m *AuthorizationValidatorInterfaceMock) EXPECT() *AuthorizationValidatorI } // validateInitialAuthorizationRequest provides a mock function for the type AuthorizationValidatorInterfaceMock -func (_mock *AuthorizationValidatorInterfaceMock) validateInitialAuthorizationRequest(msg *model.OAuthMessage, oauthApp *model0.OAuthAppConfigProcessedDTO) (bool, string, string) { +func (_mock *AuthorizationValidatorInterfaceMock) validateInitialAuthorizationRequest(msg *authz.OAuthMessage, oauthApp *model.OAuthAppConfigProcessedDTO) (bool, string, string) { ret := _mock.Called(msg, oauthApp) if len(ret) == 0 { @@ -48,20 +48,20 @@ func (_mock *AuthorizationValidatorInterfaceMock) validateInitialAuthorizationRe var r0 bool var r1 string var r2 string - if returnFunc, ok := ret.Get(0).(func(*model.OAuthMessage, *model0.OAuthAppConfigProcessedDTO) (bool, string, string)); ok { + if returnFunc, ok := ret.Get(0).(func(*authz.OAuthMessage, *model.OAuthAppConfigProcessedDTO) (bool, string, string)); ok { return returnFunc(msg, oauthApp) } - if returnFunc, ok := ret.Get(0).(func(*model.OAuthMessage, *model0.OAuthAppConfigProcessedDTO) bool); ok { + if returnFunc, ok := ret.Get(0).(func(*authz.OAuthMessage, *model.OAuthAppConfigProcessedDTO) bool); ok { r0 = returnFunc(msg, oauthApp) } else { r0 = ret.Get(0).(bool) } - if returnFunc, ok := ret.Get(1).(func(*model.OAuthMessage, *model0.OAuthAppConfigProcessedDTO) string); ok { + if returnFunc, ok := ret.Get(1).(func(*authz.OAuthMessage, *model.OAuthAppConfigProcessedDTO) string); ok { r1 = returnFunc(msg, oauthApp) } else { r1 = ret.Get(1).(string) } - if returnFunc, ok := ret.Get(2).(func(*model.OAuthMessage, *model0.OAuthAppConfigProcessedDTO) string); ok { + if returnFunc, ok := ret.Get(2).(func(*authz.OAuthMessage, *model.OAuthAppConfigProcessedDTO) string); ok { r2 = returnFunc(msg, oauthApp) } else { r2 = ret.Get(2).(string) @@ -75,21 +75,21 @@ type AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Cal } // validateInitialAuthorizationRequest is a helper method to define mock.On call -// - msg *model.OAuthMessage -// - oauthApp *model0.OAuthAppConfigProcessedDTO +// - msg *authz.OAuthMessage +// - oauthApp *model.OAuthAppConfigProcessedDTO func (_e *AuthorizationValidatorInterfaceMock_Expecter) validateInitialAuthorizationRequest(msg interface{}, oauthApp interface{}) *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call { return &AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call{Call: _e.mock.On("validateInitialAuthorizationRequest", msg, oauthApp)} } -func (_c *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call) Run(run func(msg *model.OAuthMessage, oauthApp *model0.OAuthAppConfigProcessedDTO)) *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call { +func (_c *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call) Run(run func(msg *authz.OAuthMessage, oauthApp *model.OAuthAppConfigProcessedDTO)) *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 *model.OAuthMessage + var arg0 *authz.OAuthMessage if args[0] != nil { - arg0 = args[0].(*model.OAuthMessage) + arg0 = args[0].(*authz.OAuthMessage) } - var arg1 *model0.OAuthAppConfigProcessedDTO + var arg1 *model.OAuthAppConfigProcessedDTO if args[1] != nil { - arg1 = args[1].(*model0.OAuthAppConfigProcessedDTO) + arg1 = args[1].(*model.OAuthAppConfigProcessedDTO) } run( arg0, @@ -104,7 +104,7 @@ func (_c *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationReques return _c } -func (_c *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call) RunAndReturn(run func(msg *model.OAuthMessage, oauthApp *model0.OAuthAppConfigProcessedDTO) (bool, string, string)) *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call { +func (_c *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call) RunAndReturn(run func(msg *authz.OAuthMessage, oauthApp *model.OAuthAppConfigProcessedDTO) (bool, string, string)) *AuthorizationValidatorInterfaceMock_validateInitialAuthorizationRequest_Call { _c.Call.Return(run) return _c } diff --git a/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizeServiceInterface_mock.go b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizeServiceInterface_mock.go new file mode 100644 index 00000000..aa06f37c --- /dev/null +++ b/backend/tests/mocks/oauth/oauth2/authzmock/AuthorizeServiceInterface_mock.go @@ -0,0 +1,105 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authzmock + +import ( + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" + mock "github.com/stretchr/testify/mock" +) + +// NewAuthorizeServiceInterfaceMock creates a new instance of AuthorizeServiceInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAuthorizeServiceInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *AuthorizeServiceInterfaceMock { + mock := &AuthorizeServiceInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// AuthorizeServiceInterfaceMock is an autogenerated mock type for the AuthorizeServiceInterface type +type AuthorizeServiceInterfaceMock struct { + mock.Mock +} + +type AuthorizeServiceInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *AuthorizeServiceInterfaceMock) EXPECT() *AuthorizeServiceInterfaceMock_Expecter { + return &AuthorizeServiceInterfaceMock_Expecter{mock: &_m.Mock} +} + +// GetAuthorizationCodeDetails provides a mock function for the type AuthorizeServiceInterfaceMock +func (_mock *AuthorizeServiceInterfaceMock) GetAuthorizationCodeDetails(clientID string, code string) (*authz.AuthorizationCode, error) { + ret := _mock.Called(clientID, code) + + if len(ret) == 0 { + panic("no return value specified for GetAuthorizationCodeDetails") + } + + var r0 *authz.AuthorizationCode + var r1 error + if returnFunc, ok := ret.Get(0).(func(string, string) (*authz.AuthorizationCode, error)); ok { + return returnFunc(clientID, code) + } + if returnFunc, ok := ret.Get(0).(func(string, string) *authz.AuthorizationCode); ok { + r0 = returnFunc(clientID, code) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*authz.AuthorizationCode) + } + } + if returnFunc, ok := ret.Get(1).(func(string, string) error); ok { + r1 = returnFunc(clientID, code) + } else { + r1 = ret.Error(1) + } + return r0, r1 +} + +// AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetAuthorizationCodeDetails' +type AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call struct { + *mock.Call +} + +// GetAuthorizationCodeDetails is a helper method to define mock.On call +// - clientID string +// - code string +func (_e *AuthorizeServiceInterfaceMock_Expecter) GetAuthorizationCodeDetails(clientID interface{}, code interface{}) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + return &AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call{Call: _e.mock.On("GetAuthorizationCodeDetails", clientID, code)} +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) Run(run func(clientID string, code string)) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + var arg1 string + if args[1] != nil { + arg1 = args[1].(string) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) Return(authorizationCode *authz.AuthorizationCode, err error) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Return(authorizationCode, err) + return _c +} + +func (_c *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call) RunAndReturn(run func(clientID string, code string) (*authz.AuthorizationCode, error)) *AuthorizeServiceInterfaceMock_GetAuthorizationCodeDetails_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/tests/mocks/oauth/oauth2/authzmock/sessionDataStoreInterface_mock.go b/backend/tests/mocks/oauth/oauth2/authzmock/sessionDataStoreInterface_mock.go new file mode 100644 index 00000000..4bb2d681 --- /dev/null +++ b/backend/tests/mocks/oauth/oauth2/authzmock/sessionDataStoreInterface_mock.go @@ -0,0 +1,221 @@ +// Code generated by mockery; DO NOT EDIT. +// github.com/vektra/mockery +// template: testify + +package authzmock + +import ( + "github.com/asgardeo/thunder/internal/oauth/oauth2/authz" + mock "github.com/stretchr/testify/mock" +) + +// newSessionDataStoreInterfaceMock creates a new instance of sessionDataStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func newSessionDataStoreInterfaceMock(t interface { + mock.TestingT + Cleanup(func()) +}) *sessionDataStoreInterfaceMock { + mock := &sessionDataStoreInterfaceMock{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// sessionDataStoreInterfaceMock is an autogenerated mock type for the sessionDataStoreInterface type +type sessionDataStoreInterfaceMock struct { + mock.Mock +} + +type sessionDataStoreInterfaceMock_Expecter struct { + mock *mock.Mock +} + +func (_m *sessionDataStoreInterfaceMock) EXPECT() *sessionDataStoreInterfaceMock_Expecter { + return &sessionDataStoreInterfaceMock_Expecter{mock: &_m.Mock} +} + +// AddSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) AddSession(value authz.SessionData) string { + ret := _mock.Called(value) + + if len(ret) == 0 { + panic("no return value specified for AddSession") + } + + var r0 string + if returnFunc, ok := ret.Get(0).(func(authz.SessionData) string); ok { + r0 = returnFunc(value) + } else { + r0 = ret.Get(0).(string) + } + return r0 +} + +// sessionDataStoreInterfaceMock_AddSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddSession' +type sessionDataStoreInterfaceMock_AddSession_Call struct { + *mock.Call +} + +// AddSession is a helper method to define mock.On call +// - value authz.SessionData +func (_e *sessionDataStoreInterfaceMock_Expecter) AddSession(value interface{}) *sessionDataStoreInterfaceMock_AddSession_Call { + return &sessionDataStoreInterfaceMock_AddSession_Call{Call: _e.mock.On("AddSession", value)} +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) Run(run func(value authz.SessionData)) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 authz.SessionData + if args[0] != nil { + arg0 = args[0].(authz.SessionData) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) Return(s string) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Return(s) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_AddSession_Call) RunAndReturn(run func(value authz.SessionData) string) *sessionDataStoreInterfaceMock_AddSession_Call { + _c.Call.Return(run) + return _c +} + +// ClearSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) ClearSession(key string) { + _mock.Called(key) + return +} + +// sessionDataStoreInterfaceMock_ClearSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSession' +type sessionDataStoreInterfaceMock_ClearSession_Call struct { + *mock.Call +} + +// ClearSession is a helper method to define mock.On call +// - key string +func (_e *sessionDataStoreInterfaceMock_Expecter) ClearSession(key interface{}) *sessionDataStoreInterfaceMock_ClearSession_Call { + return &sessionDataStoreInterfaceMock_ClearSession_Call{Call: _e.mock.On("ClearSession", key)} +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) Run(run func(key string)) *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) Return() *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Call.Return() + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSession_Call) RunAndReturn(run func(key string)) *sessionDataStoreInterfaceMock_ClearSession_Call { + _c.Run(run) + return _c +} + +// ClearSessionStore provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) ClearSessionStore() { + _mock.Called() + return +} + +// sessionDataStoreInterfaceMock_ClearSessionStore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSessionStore' +type sessionDataStoreInterfaceMock_ClearSessionStore_Call struct { + *mock.Call +} + +// ClearSessionStore is a helper method to define mock.On call +func (_e *sessionDataStoreInterfaceMock_Expecter) ClearSessionStore() *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + return &sessionDataStoreInterfaceMock_ClearSessionStore_Call{Call: _e.mock.On("ClearSessionStore")} +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) Run(run func()) *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) Return() *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Call.Return() + return _c +} + +func (_c *sessionDataStoreInterfaceMock_ClearSessionStore_Call) RunAndReturn(run func()) *sessionDataStoreInterfaceMock_ClearSessionStore_Call { + _c.Run(run) + return _c +} + +// GetSession provides a mock function for the type sessionDataStoreInterfaceMock +func (_mock *sessionDataStoreInterfaceMock) GetSession(key string) (bool, authz.SessionData) { + ret := _mock.Called(key) + + if len(ret) == 0 { + panic("no return value specified for GetSession") + } + + var r0 bool + var r1 authz.SessionData + if returnFunc, ok := ret.Get(0).(func(string) (bool, authz.SessionData)); ok { + return returnFunc(key) + } + if returnFunc, ok := ret.Get(0).(func(string) bool); ok { + r0 = returnFunc(key) + } else { + r0 = ret.Get(0).(bool) + } + if returnFunc, ok := ret.Get(1).(func(string) authz.SessionData); ok { + r1 = returnFunc(key) + } else { + r1 = ret.Get(1).(authz.SessionData) + } + return r0, r1 +} + +// sessionDataStoreInterfaceMock_GetSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSession' +type sessionDataStoreInterfaceMock_GetSession_Call struct { + *mock.Call +} + +// GetSession is a helper method to define mock.On call +// - key string +func (_e *sessionDataStoreInterfaceMock_Expecter) GetSession(key interface{}) *sessionDataStoreInterfaceMock_GetSession_Call { + return &sessionDataStoreInterfaceMock_GetSession_Call{Call: _e.mock.On("GetSession", key)} +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) Run(run func(key string)) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 string + if args[0] != nil { + arg0 = args[0].(string) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) Return(b bool, sessionData authz.SessionData) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Return(b, sessionData) + return _c +} + +func (_c *sessionDataStoreInterfaceMock_GetSession_Call) RunAndReturn(run func(key string) (bool, authz.SessionData)) *sessionDataStoreInterfaceMock_GetSession_Call { + _c.Call.Return(run) + return _c +} diff --git a/backend/tests/mocks/oauth/scope/providermock/ScopeValidatorProviderInterface_mock.go b/backend/tests/mocks/oauth/scope/providermock/ScopeValidatorProviderInterface_mock.go deleted file mode 100644 index 30acdebf..00000000 --- a/backend/tests/mocks/oauth/scope/providermock/ScopeValidatorProviderInterface_mock.go +++ /dev/null @@ -1,83 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package providermock - -import ( - "github.com/asgardeo/thunder/internal/oauth/scope/validator" - mock "github.com/stretchr/testify/mock" -) - -// NewScopeValidatorProviderInterfaceMock creates a new instance of ScopeValidatorProviderInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewScopeValidatorProviderInterfaceMock(t interface { - mock.TestingT - Cleanup(func()) -}) *ScopeValidatorProviderInterfaceMock { - mock := &ScopeValidatorProviderInterfaceMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// ScopeValidatorProviderInterfaceMock is an autogenerated mock type for the ScopeValidatorProviderInterface type -type ScopeValidatorProviderInterfaceMock struct { - mock.Mock -} - -type ScopeValidatorProviderInterfaceMock_Expecter struct { - mock *mock.Mock -} - -func (_m *ScopeValidatorProviderInterfaceMock) EXPECT() *ScopeValidatorProviderInterfaceMock_Expecter { - return &ScopeValidatorProviderInterfaceMock_Expecter{mock: &_m.Mock} -} - -// GetScopeValidator provides a mock function for the type ScopeValidatorProviderInterfaceMock -func (_mock *ScopeValidatorProviderInterfaceMock) GetScopeValidator() validator.ScopeValidatorInterface { - ret := _mock.Called() - - if len(ret) == 0 { - panic("no return value specified for GetScopeValidator") - } - - var r0 validator.ScopeValidatorInterface - if returnFunc, ok := ret.Get(0).(func() validator.ScopeValidatorInterface); ok { - r0 = returnFunc() - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(validator.ScopeValidatorInterface) - } - } - return r0 -} - -// ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetScopeValidator' -type ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call struct { - *mock.Call -} - -// GetScopeValidator is a helper method to define mock.On call -func (_e *ScopeValidatorProviderInterfaceMock_Expecter) GetScopeValidator() *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call { - return &ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call{Call: _e.mock.On("GetScopeValidator")} -} - -func (_c *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call) Run(run func()) *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call) Return(scopeValidatorInterface validator.ScopeValidatorInterface) *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call { - _c.Call.Return(scopeValidatorInterface) - return _c -} - -func (_c *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call) RunAndReturn(run func() validator.ScopeValidatorInterface) *ScopeValidatorProviderInterfaceMock_GetScopeValidator_Call { - _c.Call.Return(run) - return _c -} diff --git a/backend/tests/mocks/oauth/scope/validatormock/ScopeValidatorInterface_mock.go b/backend/tests/mocks/oauth/scopemock/ScopeValidatorInterface_mock.go similarity index 85% rename from backend/tests/mocks/oauth/scope/validatormock/ScopeValidatorInterface_mock.go rename to backend/tests/mocks/oauth/scopemock/ScopeValidatorInterface_mock.go index 3c8d7355..13f755aa 100644 --- a/backend/tests/mocks/oauth/scope/validatormock/ScopeValidatorInterface_mock.go +++ b/backend/tests/mocks/oauth/scopemock/ScopeValidatorInterface_mock.go @@ -2,10 +2,10 @@ // github.com/vektra/mockery // template: testify -package validatormock +package scopemock import ( - "github.com/asgardeo/thunder/internal/oauth/scope/validator" + "github.com/asgardeo/thunder/internal/oauth/scope" mock "github.com/stretchr/testify/mock" ) @@ -37,7 +37,7 @@ func (_m *ScopeValidatorInterfaceMock) EXPECT() *ScopeValidatorInterfaceMock_Exp } // ValidateScopes provides a mock function for the type ScopeValidatorInterfaceMock -func (_mock *ScopeValidatorInterfaceMock) ValidateScopes(requestedScopes string, clientID string) (string, *validator.ScopeError) { +func (_mock *ScopeValidatorInterfaceMock) ValidateScopes(requestedScopes string, clientID string) (string, *scope.ScopeError) { ret := _mock.Called(requestedScopes, clientID) if len(ret) == 0 { @@ -45,8 +45,8 @@ func (_mock *ScopeValidatorInterfaceMock) ValidateScopes(requestedScopes string, } var r0 string - var r1 *validator.ScopeError - if returnFunc, ok := ret.Get(0).(func(string, string) (string, *validator.ScopeError)); ok { + var r1 *scope.ScopeError + if returnFunc, ok := ret.Get(0).(func(string, string) (string, *scope.ScopeError)); ok { return returnFunc(requestedScopes, clientID) } if returnFunc, ok := ret.Get(0).(func(string, string) string); ok { @@ -54,11 +54,11 @@ func (_mock *ScopeValidatorInterfaceMock) ValidateScopes(requestedScopes string, } else { r0 = ret.Get(0).(string) } - if returnFunc, ok := ret.Get(1).(func(string, string) *validator.ScopeError); ok { + if returnFunc, ok := ret.Get(1).(func(string, string) *scope.ScopeError); ok { r1 = returnFunc(requestedScopes, clientID) } else { if ret.Get(1) != nil { - r1 = ret.Get(1).(*validator.ScopeError) + r1 = ret.Get(1).(*scope.ScopeError) } } return r0, r1 @@ -94,12 +94,12 @@ func (_c *ScopeValidatorInterfaceMock_ValidateScopes_Call) Run(run func(requeste return _c } -func (_c *ScopeValidatorInterfaceMock_ValidateScopes_Call) Return(s string, scopeError *validator.ScopeError) *ScopeValidatorInterfaceMock_ValidateScopes_Call { +func (_c *ScopeValidatorInterfaceMock_ValidateScopes_Call) Return(s string, scopeError *scope.ScopeError) *ScopeValidatorInterfaceMock_ValidateScopes_Call { _c.Call.Return(s, scopeError) return _c } -func (_c *ScopeValidatorInterfaceMock_ValidateScopes_Call) RunAndReturn(run func(requestedScopes string, clientID string) (string, *validator.ScopeError)) *ScopeValidatorInterfaceMock_ValidateScopes_Call { +func (_c *ScopeValidatorInterfaceMock_ValidateScopes_Call) RunAndReturn(run func(requestedScopes string, clientID string) (string, *scope.ScopeError)) *ScopeValidatorInterfaceMock_ValidateScopes_Call { _c.Call.Return(run) return _c } diff --git a/backend/tests/mocks/oauth/session/storemock/SessionDataStoreInterface_mock.go b/backend/tests/mocks/oauth/session/storemock/SessionDataStoreInterface_mock.go deleted file mode 100644 index 21dc0eb3..00000000 --- a/backend/tests/mocks/oauth/session/storemock/SessionDataStoreInterface_mock.go +++ /dev/null @@ -1,216 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package storemock - -import ( - "github.com/asgardeo/thunder/internal/oauth/session/model" - mock "github.com/stretchr/testify/mock" -) - -// NewSessionDataStoreInterfaceMock creates a new instance of SessionDataStoreInterfaceMock. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. -func NewSessionDataStoreInterfaceMock(t interface { - mock.TestingT - Cleanup(func()) -}) *SessionDataStoreInterfaceMock { - mock := &SessionDataStoreInterfaceMock{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// SessionDataStoreInterfaceMock is an autogenerated mock type for the SessionDataStoreInterface type -type SessionDataStoreInterfaceMock struct { - mock.Mock -} - -type SessionDataStoreInterfaceMock_Expecter struct { - mock *mock.Mock -} - -func (_m *SessionDataStoreInterfaceMock) EXPECT() *SessionDataStoreInterfaceMock_Expecter { - return &SessionDataStoreInterfaceMock_Expecter{mock: &_m.Mock} -} - -// AddSession provides a mock function for the type SessionDataStoreInterfaceMock -func (_mock *SessionDataStoreInterfaceMock) AddSession(key string, value model.SessionData) { - _mock.Called(key, value) - return -} - -// SessionDataStoreInterfaceMock_AddSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'AddSession' -type SessionDataStoreInterfaceMock_AddSession_Call struct { - *mock.Call -} - -// AddSession is a helper method to define mock.On call -// - key string -// - value model.SessionData -func (_e *SessionDataStoreInterfaceMock_Expecter) AddSession(key interface{}, value interface{}) *SessionDataStoreInterfaceMock_AddSession_Call { - return &SessionDataStoreInterfaceMock_AddSession_Call{Call: _e.mock.On("AddSession", key, value)} -} - -func (_c *SessionDataStoreInterfaceMock_AddSession_Call) Run(run func(key string, value model.SessionData)) *SessionDataStoreInterfaceMock_AddSession_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 string - if args[0] != nil { - arg0 = args[0].(string) - } - var arg1 model.SessionData - if args[1] != nil { - arg1 = args[1].(model.SessionData) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *SessionDataStoreInterfaceMock_AddSession_Call) Return() *SessionDataStoreInterfaceMock_AddSession_Call { - _c.Call.Return() - return _c -} - -func (_c *SessionDataStoreInterfaceMock_AddSession_Call) RunAndReturn(run func(key string, value model.SessionData)) *SessionDataStoreInterfaceMock_AddSession_Call { - _c.Run(run) - return _c -} - -// ClearSession provides a mock function for the type SessionDataStoreInterfaceMock -func (_mock *SessionDataStoreInterfaceMock) ClearSession(key string) { - _mock.Called(key) - return -} - -// SessionDataStoreInterfaceMock_ClearSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSession' -type SessionDataStoreInterfaceMock_ClearSession_Call struct { - *mock.Call -} - -// ClearSession is a helper method to define mock.On call -// - key string -func (_e *SessionDataStoreInterfaceMock_Expecter) ClearSession(key interface{}) *SessionDataStoreInterfaceMock_ClearSession_Call { - return &SessionDataStoreInterfaceMock_ClearSession_Call{Call: _e.mock.On("ClearSession", key)} -} - -func (_c *SessionDataStoreInterfaceMock_ClearSession_Call) Run(run func(key string)) *SessionDataStoreInterfaceMock_ClearSession_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 string - if args[0] != nil { - arg0 = args[0].(string) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *SessionDataStoreInterfaceMock_ClearSession_Call) Return() *SessionDataStoreInterfaceMock_ClearSession_Call { - _c.Call.Return() - return _c -} - -func (_c *SessionDataStoreInterfaceMock_ClearSession_Call) RunAndReturn(run func(key string)) *SessionDataStoreInterfaceMock_ClearSession_Call { - _c.Run(run) - return _c -} - -// ClearSessionStore provides a mock function for the type SessionDataStoreInterfaceMock -func (_mock *SessionDataStoreInterfaceMock) ClearSessionStore() { - _mock.Called() - return -} - -// SessionDataStoreInterfaceMock_ClearSessionStore_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ClearSessionStore' -type SessionDataStoreInterfaceMock_ClearSessionStore_Call struct { - *mock.Call -} - -// ClearSessionStore is a helper method to define mock.On call -func (_e *SessionDataStoreInterfaceMock_Expecter) ClearSessionStore() *SessionDataStoreInterfaceMock_ClearSessionStore_Call { - return &SessionDataStoreInterfaceMock_ClearSessionStore_Call{Call: _e.mock.On("ClearSessionStore")} -} - -func (_c *SessionDataStoreInterfaceMock_ClearSessionStore_Call) Run(run func()) *SessionDataStoreInterfaceMock_ClearSessionStore_Call { - _c.Call.Run(func(args mock.Arguments) { - run() - }) - return _c -} - -func (_c *SessionDataStoreInterfaceMock_ClearSessionStore_Call) Return() *SessionDataStoreInterfaceMock_ClearSessionStore_Call { - _c.Call.Return() - return _c -} - -func (_c *SessionDataStoreInterfaceMock_ClearSessionStore_Call) RunAndReturn(run func()) *SessionDataStoreInterfaceMock_ClearSessionStore_Call { - _c.Run(run) - return _c -} - -// GetSession provides a mock function for the type SessionDataStoreInterfaceMock -func (_mock *SessionDataStoreInterfaceMock) GetSession(key string) (bool, model.SessionData) { - ret := _mock.Called(key) - - if len(ret) == 0 { - panic("no return value specified for GetSession") - } - - var r0 bool - var r1 model.SessionData - if returnFunc, ok := ret.Get(0).(func(string) (bool, model.SessionData)); ok { - return returnFunc(key) - } - if returnFunc, ok := ret.Get(0).(func(string) bool); ok { - r0 = returnFunc(key) - } else { - r0 = ret.Get(0).(bool) - } - if returnFunc, ok := ret.Get(1).(func(string) model.SessionData); ok { - r1 = returnFunc(key) - } else { - r1 = ret.Get(1).(model.SessionData) - } - return r0, r1 -} - -// SessionDataStoreInterfaceMock_GetSession_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetSession' -type SessionDataStoreInterfaceMock_GetSession_Call struct { - *mock.Call -} - -// GetSession is a helper method to define mock.On call -// - key string -func (_e *SessionDataStoreInterfaceMock_Expecter) GetSession(key interface{}) *SessionDataStoreInterfaceMock_GetSession_Call { - return &SessionDataStoreInterfaceMock_GetSession_Call{Call: _e.mock.On("GetSession", key)} -} - -func (_c *SessionDataStoreInterfaceMock_GetSession_Call) Run(run func(key string)) *SessionDataStoreInterfaceMock_GetSession_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 string - if args[0] != nil { - arg0 = args[0].(string) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *SessionDataStoreInterfaceMock_GetSession_Call) Return(b bool, sessionData model.SessionData) *SessionDataStoreInterfaceMock_GetSession_Call { - _c.Call.Return(b, sessionData) - return _c -} - -func (_c *SessionDataStoreInterfaceMock_GetSession_Call) RunAndReturn(run func(key string) (bool, model.SessionData)) *SessionDataStoreInterfaceMock_GetSession_Call { - _c.Call.Return(run) - return _c -} diff --git a/tests/integration/identity/oauth2/authz/authz_test.go b/tests/integration/identity/oauth2/authz/authz_test.go index 1b0dcf11..3d1c8fb4 100644 --- a/tests/integration/identity/oauth2/authz/authz_test.go +++ b/tests/integration/identity/oauth2/authz/authz_test.go @@ -337,28 +337,8 @@ func (ts *AuthzTestSuite) TestTokenRequestValidation() { }() // Get a valid authorization code first - resp, err := initiateAuthorizationFlow(clientID, redirectURI, "code", "openid", "token_test_state") - ts.NoError(err, "Failed to initiate authorization flow") - defer resp.Body.Close() - - ts.Equal(http.StatusFound, resp.StatusCode, "Expected redirect status") - location := resp.Header.Get("Location") - sessionDataKey, _, err := extractSessionData(location) - ts.NoError(err, "Failed to extract session data") - - // Execute authentication flow - flowStep, err := ExecuteAuthenticationFlow(ts.applicationID, map[string]string{ - "username": username, - "password": password, - }) - ts.NoError(err, "Failed to execute authentication flow") - ts.Equal("COMPLETE", flowStep.FlowStatus, "Flow should complete successfully") - - // Complete authorization - authzResponse, err := completeAuthorization(sessionDataKey, flowStep.Assertion) - ts.NoError(err, "Failed to complete authorization") - validAuthzCode, err := extractAuthorizationCode(authzResponse.RedirectURI) - ts.NoError(err, "Failed to extract authorization code") + validAuthzCode := initiateAuthorizeFlowAndRetrieveAuthzCode(ts, username, password) + anotherValidAuthzCode := initiateAuthorizeFlowAndRetrieveAuthzCode(ts, username, password) testCases := []struct { Name string @@ -451,21 +431,21 @@ func (ts *AuthzTestSuite) TestTokenRequestValidation() { ExpectedError: "invalid_client", }, { - Name: "Invalid Client Secret", + Name: "Mismatched Redirect URI", ClientID: clientID, - ClientSecret: "wrong_secret", - Code: validAuthzCode, - RedirectURI: redirectURI, + ClientSecret: clientSecret, + Code: anotherValidAuthzCode, + RedirectURI: "https://localhost:3001", GrantType: "authorization_code", - ExpectedStatus: http.StatusUnauthorized, - ExpectedError: "invalid_client", + ExpectedStatus: http.StatusBadRequest, + ExpectedError: "invalid_grant", }, { - Name: "Mismatched Redirect URI", + Name: "Used unsuccessful Authz Code", ClientID: clientID, ClientSecret: clientSecret, - Code: validAuthzCode, - RedirectURI: "https://localhost:3001", + Code: anotherValidAuthzCode, + RedirectURI: redirectURI, GrantType: "authorization_code", ExpectedStatus: http.StatusBadRequest, ExpectedError: "invalid_grant", @@ -490,6 +470,16 @@ func (ts *AuthzTestSuite) TestTokenRequestValidation() { ExpectedStatus: http.StatusOK, ExpectedError: "", }, + { + Name: "Used successful Authz Code", + ClientID: clientID, + ClientSecret: clientSecret, + Code: validAuthzCode, + RedirectURI: redirectURI, + GrantType: "authorization_code", + ExpectedStatus: http.StatusBadRequest, + ExpectedError: "invalid_grant", + }, } for _, tc := range testCases { @@ -530,6 +520,32 @@ func (ts *AuthzTestSuite) TestTokenRequestValidation() { } } +func initiateAuthorizeFlowAndRetrieveAuthzCode(ts *AuthzTestSuite, username string, password string) string { + resp, err := initiateAuthorizationFlow(clientID, redirectURI, "code", "openid", "token_test_state") + ts.NoError(err, "Failed to initiate authorization flow") + defer resp.Body.Close() + + ts.Equal(http.StatusFound, resp.StatusCode, "Expected redirect status") + location := resp.Header.Get("Location") + sessionDataKey, _, err := extractSessionData(location) + ts.NoError(err, "Failed to extract session data") + + // Execute authentication flow + flowStep, err := ExecuteAuthenticationFlow(ts.applicationID, map[string]string{ + "username": username, + "password": password, + }) + ts.NoError(err, "Failed to execute authentication flow") + ts.Equal("COMPLETE", flowStep.FlowStatus, "Flow should complete successfully") + + // Complete authorization + authzResponse, err := completeAuthorization(sessionDataKey, flowStep.Assertion) + ts.NoError(err, "Failed to complete authorization") + validAuthzCode, err := extractAuthorizationCode(authzResponse.RedirectURI) + ts.NoError(err, "Failed to extract authorization code") + return validAuthzCode +} + // TestRedirectURIValidation tests the redirect URI validation in OAuth2 flows func (ts *AuthzTestSuite) TestRedirectURIValidation() { testCases := []struct {