Skip to content

Commit

Permalink
fix custom auth header
Browse files Browse the repository at this point in the history
  • Loading branch information
tharindu1st committed Feb 18, 2025
1 parent a17623d commit 59a13ee
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 31 deletions.
47 changes: 42 additions & 5 deletions adapter/internal/discovery/xds/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,16 +273,49 @@ func GenerateEnvoyResoucesForGateway(gatewayName string) ([]types.Resource,
}
orgwizeJWTProviders := envoyGatewayConfig.jwtProviders
jwtRequirementMap := make(map[string]*jwt.JwtRequirement)
jwtProviderMap := make(map[string]*jwt.JwtProvider)
for organizationID, entityMap := range orgAPIMap {
for apiKey, envoyInternalAPI := range entityMap {
if _, exists := envoyInternalAPI.envoyLabels[gatewayName]; !exists {
// do nothing if the gateway is not found in the envoyInternalAPI
continue
}
if !envoyInternalAPI.adapterInternalAPI.GetDisableAuthentications() {
jwtRequirements := oasParser.GetJWTRequirements(envoyInternalAPI.adapterInternalAPI, orgwizeJWTProviders[organizationID])
if jwtRequirements != nil {
jwtRequirementMap[envoyInternalAPI.adapterInternalAPI.UUID] = jwtRequirements
var authorizationHeader *string
var sendTokenToUpstream *bool
if envoyInternalAPI.adapterInternalAPI != nil && envoyInternalAPI.adapterInternalAPI.GetResources() != nil {
for _, resource := range envoyInternalAPI.adapterInternalAPI.GetResources() {
if resource.GetMethod() != nil {
for _, method := range resource.GetMethod() {
if method.GetAuthentication() != nil {
if !method.GetAuthentication().Disabled && method.GetAuthentication().Oauth2 != nil {
authorizationHeader = &method.GetAuthentication().Oauth2.Header
sendTokenToUpstream = &method.GetAuthentication().Oauth2.SendTokenToUpstream
break
}
}
}
}
}
}
if authorizationHeader != nil || sendTokenToUpstream != nil {
jwtProviders, jwtclusters, jwtaddress, jwtRequirement, err := oasParser.GenerateAPILevelJWTPRoviders(orgwizeJWTProviders[organizationID], envoyInternalAPI.adapterInternalAPI, authorizationHeader, sendTokenToUpstream)
if err != nil {
logger.LoggerXds.ErrorC(logging.PrintError(logging.Error2301, logging.MAJOR, "Error generating JWT Providers: %v", err))
}
clusterArray = append(clusterArray, jwtclusters...)
endpointArray = append(endpointArray, jwtaddress...)
for key, value := range jwtProviders {
jwtProviderMap[key] = value
}
if jwtRequirement != nil {
jwtRequirementMap[envoyInternalAPI.adapterInternalAPI.UUID] = jwtRequirement
}
} else {
jwtRequirements := oasParser.GetJWTRequirements(envoyInternalAPI.adapterInternalAPI, orgwizeJWTProviders[organizationID])
if jwtRequirements != nil {
jwtRequirementMap[envoyInternalAPI.adapterInternalAPI.UUID] = jwtRequirements
}
}
}
vhost, err := ExtractVhostFromAPIIdentifier(apiKey)
Expand Down Expand Up @@ -320,13 +353,17 @@ func GenerateEnvoyResoucesForGateway(gatewayName string) ([]types.Resource,
readynessEndpoint := envoyconf.CreateReadyEndpoint()
vhostToRouteArrayMap[systemHost] = append(vhostToRouteArrayMap[systemHost], readynessEndpoint)
}
jwtProviders, jwtclusters, jwtaddress, err := oasParser.GenerateJWTPRoviders(orgwizeJWTProviders)
jwtProviders, jwtclusters, jwtaddress, err := oasParser.GenerateJWTPRoviderv3(orgwizeJWTProviders)
if err != nil {
logger.LoggerXds.ErrorC(logging.PrintError(logging.Error1100, logging.MAJOR, "Error generating JWT Providers: %v", err))
}
for key, value := range jwtProviders {
jwtProviderMap[key] = value
}

clusterArray = append(clusterArray, jwtclusters...)
endpointArray = append(endpointArray, jwtaddress...)
jwtFilter, err := oasParser.GetJWTFilter(jwtRequirementMap, jwtProviders)
jwtFilter, err := oasParser.GetJWTFilter(jwtRequirementMap, jwtProviderMap)
listeners := envoyGatewayConfig.listeners
if !config.ReadConfigs().Adapter.EnableGatewayClassController && len(listeners) < 1 {
return nil, nil, nil, nil, nil
Expand Down
5 changes: 5 additions & 0 deletions adapter/internal/logging/logging_constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const (
Error2248 = 2248
Error2249 = 2249
Error2250 = 2250
Error2301 = 2301
)

// Error Log RateLimiter callbacks(2300-2399) Config Constants
Expand Down Expand Up @@ -418,4 +419,8 @@ var Mapper = map[int]logging.ErrorDetails{
ErrorCode: Error2300,
Message: "Error in Stream request.",
},
Error2301: {
ErrorCode: Error2301,
Message: "Error while generating JWTProviders",
},
}
50 changes: 48 additions & 2 deletions adapter/internal/oasparser/config_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,12 @@ func GetJWTRequirements(adapterAPI *model.AdapterInternalAPI, jwtIssuers map[str
selectedIssuers = append(selectedIssuers, issuserName)
}
}
return GetAPILevelJWTRequirements(adapterAPI, selectedIssuers)

}

// GetAPILevelJWTRequirements returns the jwt requirements for the resource
func GetAPILevelJWTRequirements(adapterAPI *model.AdapterInternalAPI, selectedIssuers []string) *jwt.JwtRequirement {
if len(selectedIssuers) >= 1 {
return &jwt.JwtRequirement{
RequiresType: &jwt.JwtRequirement_RequiresAny{
Expand All @@ -487,8 +493,44 @@ func GetJWTRequirements(adapterAPI *model.AdapterInternalAPI, jwtIssuers map[str
return nil
}

// GenerateAPILevelJWTPRoviders generates the jwt provider for the resource
func GenerateAPILevelJWTPRoviders(jwtIssuers map[string]*v1alpha1.ResolvedJWTIssuer, adapterAPI *model.AdapterInternalAPI, authorizationHeader *string, sendTokenToUpStream *bool) (map[string]*jwt.JwtProvider, []*clusterv3.Cluster, []*corev3.Address, *jwt.JwtRequirement, error) {
jwtProviders := map[string]*jwt.JwtProvider{}
var clusters []*clusterv3.Cluster
var addresses []*corev3.Address
var selectedIssuers []string
for issuerMappingName, jwtIssuer := range jwtIssuers {
providerName := adapterAPI.UUID + "-" + issuerMappingName
var seleced bool
if contains(jwtIssuer.Environments, "*") {
selectedIssuers = append(selectedIssuers, providerName)
seleced = true
} else if contains(jwtIssuer.Environments, adapterAPI.GetEnvironment()) {
selectedIssuers = append(selectedIssuers, providerName)
seleced = true
}
if seleced {
provider, cluster, address, err := getjwtAuthFilters(jwtIssuer, providerName)
if err != nil {
return nil, nil, nil, nil, err
}
if authorizationHeader != nil {
provider.FromHeaders = []*jwt.JwtHeader{{Name: *authorizationHeader, ValuePrefix: "Bearer "}}
}
if sendTokenToUpStream != nil {
provider.Forward = *sendTokenToUpStream
}
jwtProviders[providerName] = provider
clusters = append(clusters, cluster...)
addresses = append(addresses, address...)
}
}
requirements := GetAPILevelJWTRequirements(adapterAPI, selectedIssuers)
return jwtProviders, clusters, addresses, requirements, nil
}

// GenerateJWTPRoviderv3 generates the jwt provider for the resource
func GenerateJWTPRoviders(jwtProviderMap map[string]map[string]*v1alpha1.ResolvedJWTIssuer) (map[string]*jwt.JwtProvider, []*clusterv3.Cluster, []*corev3.Address, error) {
func GenerateJWTPRoviderv3(jwtProviderMap map[string]map[string]*v1alpha1.ResolvedJWTIssuer) (map[string]*jwt.JwtProvider, []*clusterv3.Cluster, []*corev3.Address, error) {
jwtProviders := map[string]*jwt.JwtProvider{}
var clusters []*clusterv3.Cluster
var addresses []*corev3.Address
Expand Down Expand Up @@ -517,14 +559,18 @@ func contains(arr []string, str string) bool {
return false
}
func getjwtAuthFilters(tokenIssuer *v1alpha1.ResolvedJWTIssuer, issuerName string) (*jwt.JwtProvider, []*clusterv3.Cluster, []*corev3.Address, error) {
conf := config.ReadConfigs()

jwksClusters := make([]*clusterv3.Cluster, 0)
jwksAddresses := make([]*corev3.Address, 0)
jwtProvider := &jwt.JwtProvider{
Issuer: tokenIssuer.Issuer,
Forward: true,
FailedStatusInMetadata: tokenIssuer.Issuer + "-failed",
PayloadInMetadata: tokenIssuer.Issuer + "-payload",
}
if conf.Enforcer.Cache.Enabled {
jwtProvider.JwtCacheConfig = &jwt.JwtCacheConfig{JwtCacheSize: uint32(conf.Enforcer.Cache.MaximumSize)}
}
if tokenIssuer.SignatureValidation.JWKS != nil {
logger.LoggerOasparser.Infof("JWKS URL: %s", tokenIssuer.SignatureValidation.JWKS.URL)
jwksCluster, jwksAddress, err := getRemoteJWKSCluster(*tokenIssuer.SignatureValidation.JWKS, issuerName)
Expand Down
24 changes: 11 additions & 13 deletions gateway/enforcer/internal/authentication/jwt_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,17 @@ import (

// ValidateToken validates the JWT token.
func ValidateToken(rch *requestconfig.Holder, jwtTransformer *transformer.JWTTransformer, revokedJTIStore *datastore.RevokedJTIStore) *dto.ImmediateResponse {
if rch != nil {
if rch.ExternalProcessingEnvoyMetadata.JwtAuthenticationData != nil {
jwtValidationInfo := jwtTransformer.TransformJWTClaims(rch.MatchedAPI.OrganizationID, rch.ExternalProcessingEnvoyMetadata.JwtAuthenticationData)
if jwtValidationInfo != nil {
if revokedJTIStore != nil && revokedJTIStore.IsJTIRevoked(jwtValidationInfo.JTI) {
errorResponse := &dto.ErrorResponse{ErrorMessage: "Invalid Credentials", Code: "900901", ErrorDescription: "Make sure you have provided the correct security credentials"}
jsonData, _ := json.MarshalIndent(errorResponse, "", " ")
return &dto.ImmediateResponse{StatusCode: 401, Message: string(jsonData)}
}
if jwtValidationInfo.Valid {
rch.JWTValidationInfo = jwtValidationInfo
return nil
}
if rch != nil && rch.ExternalProcessingEnvoyMetadata != nil && rch.ExternalProcessingEnvoyMetadata.JwtAuthenticationData != nil {
jwtValidationInfo := jwtTransformer.TransformJWTClaims(rch.MatchedAPI.OrganizationID, rch.ExternalProcessingEnvoyMetadata.JwtAuthenticationData)
if jwtValidationInfo != nil {
if revokedJTIStore != nil && revokedJTIStore.IsJTIRevoked(jwtValidationInfo.JTI) {
errorResponse := &dto.ErrorResponse{ErrorMessage: "Invalid Credentials", Code: "900901", ErrorDescription: "Make sure you have provided the correct security credentials"}
jsonData, _ := json.MarshalIndent(errorResponse, "", " ")
return &dto.ImmediateResponse{StatusCode: 401, Message: string(jsonData)}
}
if jwtValidationInfo.Valid {
rch.JWTValidationInfo = jwtValidationInfo
return nil
}
}
}
Expand Down
21 changes: 10 additions & 11 deletions gateway/enforcer/internal/extproc/ext_proc.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,14 @@ func StartExternalProcessingServer(cfg *config.Server, apiStore *datastore.APISt
}

ratelimitHelper := ratelimit.NewAIRatelimitHelper(cfg)
envoy_service_proc_v3.RegisterExternalProcessorServer(server,
&ExternalProcessingServer{cfg.Logger,
apiStore,
subAppDatastore,
ratelimitHelper,
cfg,
jwtTransformer,
modelBasedRoundRobinTracker,
envoy_service_proc_v3.RegisterExternalProcessorServer(server,
&ExternalProcessingServer{cfg.Logger,
apiStore,
subAppDatastore,
ratelimitHelper,
cfg,
jwtTransformer,
modelBasedRoundRobinTracker,
revokedJTIStore})
listener, err := net.Listen("tcp", fmt.Sprintf(":%s", cfg.ExternalProcessingPort))
if err != nil {
Expand Down Expand Up @@ -984,8 +984,8 @@ func extractExternalProcessingMetadata(data *corev3.Metadata) (*dto.ExternalProc
if filterMatadata != nil {
externalProcessingEnvoyMetadata := &dto.ExternalProcessingEnvoyMetadata{}
jwtFilterdata := filterMatadata["envoy.filters.http.jwt_authn"]
authenticationData := &dto.JwtAuthenticationData{}
if jwtFilterdata != nil {
authenticationData := &dto.JwtAuthenticationData{}
for key, structValue := range jwtFilterdata.Fields {
if strings.HasSuffix(key, "-payload") {
sucessData := dto.JWTAuthenticationSuccessData{}
Expand Down Expand Up @@ -1029,7 +1029,6 @@ func extractExternalProcessingMetadata(data *corev3.Metadata) (*dto.ExternalProc
}
}
}

externalProcessingEnvoyMetadata.JwtAuthenticationData = authenticationData
}
if extProcMetadata, exists := filterMatadata[externalProessingMetadataContextKey]; exists {
Expand All @@ -1049,7 +1048,7 @@ func extractExternalProcessingMetadata(data *corev3.Metadata) (*dto.ExternalProc
}
return externalProcessingEnvoyMetadata, nil
}
return nil, fmt.Errorf("could not find the filter metadata")
return nil, nil
}

// ReadGzip decompresses a GZIP-compressed byte slice and returns the string output
Expand Down

0 comments on commit 59a13ee

Please sign in to comment.