Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ require (
github.com/yuin/gopher-lua v1.1.1
gitlab.com/gitlab-org/api/client-go v1.29.0
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.64.0
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0
go.opentelemetry.io/otel v1.39.0
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0
go.opentelemetry.io/otel/sdk v1.39.0
go.opentelemetry.io/otel/trace v1.39.0
golang.org/x/crypto v0.48.0
golang.org/x/net v0.50.0
golang.org/x/oauth2 v0.34.0
Expand Down Expand Up @@ -273,10 +275,8 @@ require (
github.com/xlab/treeprint v1.2.0 // indirect
go.mongodb.org/mongo-driver v1.17.6 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.59.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
go.opentelemetry.io/otel/metric v1.39.0 // indirect
go.opentelemetry.io/otel/trace v1.39.0 // indirect
go.opentelemetry.io/proto/otlp v1.9.0 // indirect
go.yaml.in/yaml/v2 v2.4.2 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
Expand Down
29 changes: 23 additions & 6 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ import (
log "github.com/sirupsen/logrus"
"github.com/soheilhy/cmux"
"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel"
otel_codes "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand Down Expand Up @@ -166,13 +169,17 @@ var (
enableGRPCTimeHistogram = true
)

// OpenTelemetry tracer for this package
var tracer trace.Tracer

func init() {
maxConcurrentLoginRequestsCount = env.ParseNumFromEnv(maxConcurrentLoginRequestsCountEnv, maxConcurrentLoginRequestsCount, 0, math.MaxInt32)
replicasCount = env.ParseNumFromEnv(replicasCountEnv, replicasCount, 0, math.MaxInt32)
if replicasCount > 0 {
maxConcurrentLoginRequestsCount = maxConcurrentLoginRequestsCount / replicasCount
}
enableGRPCTimeHistogram = env.ParseBoolFromEnv(common.EnvEnableGRPCTimeHistogramEnv, false)
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/server")
}

// ArgoCDServer is the API server for Argo CD
Expand Down Expand Up @@ -1164,8 +1171,8 @@ func (server *ArgoCDServer) newHTTPServer(ctx context.Context, port int, grpcWeb
Handler: &handlerSwitcher{
handler: mux,
urlToHandler: map[string]http.Handler{
"/api/badge": badge.NewHandler(server.AppClientset, server.settingsMgr, server.Namespace, server.ApplicationNamespaces),
common.LogoutEndpoint: logout.NewHandler(server.settingsMgr, server.sessionMgr, server.RootPath, server.BaseHRef),
"/api/badge": otelhttp.NewHandler(badge.NewHandler(server.AppClientset, server.settingsMgr, server.Namespace, server.ApplicationNamespaces), "server.ArgoCDServer/badge"),
common.LogoutEndpoint: otelhttp.NewHandler(logout.NewHandler(server.settingsMgr, server.sessionMgr, server.RootPath, server.BaseHRef), "server.ArgoCDServer/logout"),
},
contentTypeToHandler: map[string]http.Handler{
"application/grpc-web+proto": grpcWebHandler,
Expand Down Expand Up @@ -1293,7 +1300,7 @@ func registerExtensions(mux *http.ServeMux, a *ArgoCDServer, metricsReg HTTPMetr
extHandler := http.HandlerFunc(a.extensionManager.CallExtension())
authMiddleware := a.sessionMgr.AuthMiddlewareFunc(a.DisableAuth, a.settings.IsSSOConfigured(), a.ssoClientApp)
// auth middleware ensures that requests to all extensions are authenticated first
mux.Handle(extension.URLPrefix+"/", authMiddleware(extHandler))
mux.Handle(extension.URLPrefix+"/", otelhttp.NewHandler(authMiddleware(extHandler), "server.ArgoCDServer/extensions"))

a.extensionManager.AddMetricsRegistry(metricsReg)

Expand Down Expand Up @@ -1351,9 +1358,10 @@ func (server *ArgoCDServer) registerDexHandlers(mux *http.ServeMux) {
return
}
// Run dex OpenID Connect Identity Provider behind a reverse proxy (served at /api/dex)
mux.HandleFunc(common.DexAPIEndpoint+"/", dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig))
mux.HandleFunc(common.LoginEndpoint, server.ssoClientApp.HandleLogin)
mux.HandleFunc(common.CallbackEndpoint, server.ssoClientApp.HandleCallback)
mux.Handle(common.DexAPIEndpoint+"/", otelhttp.NewHandler(http.HandlerFunc(dexutil.NewDexHTTPReverseProxy(server.DexServerAddr, server.BaseHRef, server.DexTLSConfig)), "server.dex/Proxy"))

mux.Handle(common.LoginEndpoint, otelhttp.NewHandler(http.HandlerFunc(server.ssoClientApp.HandleLogin), "server.ClientApp/HandleLogin"))
mux.Handle(common.CallbackEndpoint, otelhttp.NewHandler(http.HandlerFunc(server.ssoClientApp.HandleCallback), "server.ClientApp/HandleCallback"))
}

// newRedirectServer returns an HTTP server which does a 307 redirect to the HTTPS server
Expand Down Expand Up @@ -1510,6 +1518,9 @@ func replaceBaseHRef(data string, replaceWith string) string {

// Authenticate checks for the presence of a valid token when accessing server-side resources.
func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "server.ArgoCDServer.Authenticate")
defer span.End()
if server.DisableAuth {
return ctx, nil
}
Expand Down Expand Up @@ -1549,18 +1560,24 @@ func (server *ArgoCDServer) Authenticate(ctx context.Context) (context.Context,

// getClaims extracts, validates and refreshes a JWT token from an incoming request context.
func (server *ArgoCDServer) getClaims(ctx context.Context) (jwt.Claims, string, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "server.ArgoCDServer.getClaims")
defer span.End()
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
span.SetStatus(otel_codes.Error, ErrNoSession.Error())
return nil, "", ErrNoSession
}
tokenString := getToken(md)
if tokenString == "" {
span.SetStatus(otel_codes.Error, ErrNoSession.Error())
return nil, "", ErrNoSession
}
// A valid argocd-issued token is automatically refreshed here prior to expiration.
// OIDC tokens will be verified but will not be refreshed here.
claims, newToken, err := server.sessionMgr.VerifyToken(ctx, tokenString)
if err != nil {
span.SetStatus(otel_codes.Error, err.Error())
return claims, "", status.Errorf(codes.Unauthenticated, "invalid session: %v", err)
}

Expand Down
109 changes: 93 additions & 16 deletions util/oidc/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ import (
"sync"
"time"

"go.opentelemetry.io/otel/codes"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

gooidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
Expand All @@ -38,6 +44,13 @@ import (

var ErrInvalidRedirectURL = errors.New("invalid return URL")

// OpenTelemetry tracer for this package
var tracer trace.Tracer

func init() {
tracer = otel.Tracer("github.com/argoproj/argo-cd/v3/util/oidc")
}

const (
GrantTypeAuthorizationCode = "authorization_code"
GrantTypeImplicit = "implicit"
Expand Down Expand Up @@ -147,6 +160,9 @@ func GetOidcTokenCacheFromJSON(jsonBytes []byte) (*OidcTokenCache, error) {
// GetTokenSourceFromCache creates an oauth2 TokenSource from a cached oidc token. The TokenSource will be configured
// with an early expiration based on the refreshTokenThreshold.
func (a *ClientApp) GetTokenSourceFromCache(ctx context.Context, oidcTokenCache *OidcTokenCache) (oauth2.TokenSource, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetTokenSourceFromCache")
defer span.End()
if oidcTokenCache == nil {
return nil, errors.New("oidcTokenCache is required")
}
Expand Down Expand Up @@ -198,10 +214,18 @@ func NewClientApp(settings *settings.ArgoCDSettings, dexServerAddr string, dexTL

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
_, span := tracer.Start(ctx, "oidc.ClientApp.client")
defer span.End()
span.SetAttributes(
attribute.String("network", network),
attribute.String("addr", addr),
)
return (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial(network, addr)
},
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
Expand Down Expand Up @@ -541,7 +565,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
}
// save the accessToken in memory for later use
sub := jwtutil.StringField(claims, "sub")
err = a.SetValueInEncryptedCache(FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
err = a.SetValueInEncryptedCache(ctx, FormatAccessTokenCacheKey(sub), []byte(token.AccessToken), GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
log.Errorf("cannot cache encrypted accessToken: %v (claims=%s)", err, claimsJSON)
Expand All @@ -557,7 +581,7 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
return
}
sid := jwtutil.StringField(claims, "sid")
err = a.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
err = a.SetValueInEncryptedCache(ctx, formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, GetTokenExpiration(claims))
if err != nil {
claimsJSON, _ := json.Marshal(claims)
log.Errorf("cannot cache encrypted oidc token: %v (claims=%s)", err, claimsJSON)
Expand Down Expand Up @@ -587,56 +611,87 @@ func (a *ClientApp) HandleCallback(w http.ResponseWriter, r *http.Request) {
// GetValueFromEncryptedCache is a convenience method for retreiving a value from cache and decrypting it. If the cache
// does not contain a value for the given key, a nil value is returned. Return handling should check for error and then
// check for nil.
func (a *ClientApp) GetValueFromEncryptedCache(key string) (value []byte, err error) {
func (a *ClientApp) GetValueFromEncryptedCache(ctx context.Context, key string) (value []byte, err error) {
_, span := tracer.Start(ctx, "oidc.ClientApp.GetValueFromEncryptedCache")
defer span.End()
var encryptedValue []byte
span.AddEvent("start cache read")
err = a.clientCache.Get(key, &encryptedValue)
span.AddEvent("end cache read")
if err != nil {
if errors.Is(err, cache.ErrCacheMiss) {
span.SetAttributes(attribute.Bool("cache_hit", false))
// Return nil to signify a cache miss
return nil, nil
}
return nil, fmt.Errorf("failed to get encrypted value from cache: %w", err)
err = fmt.Errorf("failed to get encrypted value from cache: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
span.SetAttributes(attribute.Bool("cache_hit", true))
value, err = crypto.Decrypt(encryptedValue, a.encryptionKey)
if err != nil {
return nil, fmt.Errorf("failed to decrypt value from cache: %w", err)
err = fmt.Errorf("failed to decrypt value from cache: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
return value, err
}

// SetValueFromEncyrptedCache is a convenience method for encrypting a value and storing it in the cache at a given key.
// Cache expiration is set based on input.
func (a *ClientApp) SetValueInEncryptedCache(key string, value []byte, expiration time.Duration) error {
func (a *ClientApp) SetValueInEncryptedCache(ctx context.Context, key string, value []byte, expiration time.Duration) error {
_, span := tracer.Start(ctx, "oidc.ClientApp.SetValueInEncryptedCache")
defer span.End()
encryptedValue, err := crypto.Encrypt(value, a.encryptionKey)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}
span.SetAttributes(
attribute.String("key", key),
attribute.Int("value_length", len(value)),
)
span.AddEvent("start cache write")
err = a.clientCache.Set(&cache.Item{
Key: key,
Object: encryptedValue,
CacheActionOpts: cache.CacheActionOpts{
Expiration: expiration,
},
})
span.AddEvent("end cache write")
if err != nil {
span.SetStatus(codes.Error, err.Error())
return err
}
return nil
}

func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.MapClaims, refreshTokenThreshold time.Duration) (string, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.CheckAndRefreshToken")
defer span.End()
iss := jwtutil.StringField(groupClaims, "iss")
sub := jwtutil.StringField(groupClaims, "sub")
sid := jwtutil.StringField(groupClaims, "sid")
span.SetAttributes(
attribute.String("iss", iss),
attribute.String("sub", sub),
attribute.String("sid", sid))
if GetTokenExpiration(groupClaims) < refreshTokenThreshold {
token, err := a.GetUpdatedOidcTokenFromCache(ctx, sub, sid)
if err != nil {
log.Errorf("Failed to get token from cache: %v", err)
span.SetStatus(codes.Error, err.Error())
return "", err
}
if token != nil {
idTokenRAW, ok := token.Extra("id_token").(string)
if !ok {
return "", errors.New("empty id_token")
err = errors.New("empty id_token")
span.SetStatus(codes.Error, err.Error())
return "", err
}
return idTokenRAW, nil
}
Expand All @@ -647,12 +702,21 @@ func (a *ClientApp) CheckAndRefreshToken(ctx context.Context, groupClaims jwt.Ma
// GetUpdatedOidcTokenFromCache fetches a token from cache and refreshes it if under the threshold for expiration.
// The cached token will also be updated if it is refreshed. Returns latest token or an error if the process fails.
func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject string, sessionId string) (*oauth2.Token, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUpdatedOidcTokenFromCache")
defer span.End()

ctx = gooidc.ClientContext(ctx, a.client)
span.SetAttributes(
attribute.String("subject", subject),
attribute.String("sessionId", sessionId),
)

// Get oauth2 config
cacheKey := formatOidcTokenCacheKey(subject, sessionId)
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(cacheKey)
oidcTokenCacheJSON, err := a.GetValueFromEncryptedCache(ctx, cacheKey)
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if oidcTokenCacheJSON == nil {
Expand All @@ -662,25 +726,35 @@ func (a *ClientApp) GetUpdatedOidcTokenFromCache(ctx context.Context, subject st
oidcTokenCache, err := GetOidcTokenCacheFromJSON(oidcTokenCacheJSON)
if err != nil {
err = fmt.Errorf("failed to unmarshal cached oidc token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
tokenSource, err := a.GetTokenSourceFromCache(ctx, oidcTokenCache)
if err != nil {
err = fmt.Errorf("failed to get token source from cached oidc token: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
span.AddEvent("starting tokenSource.Token()")
token, err := tokenSource.Token()
span.AddEvent("finished tokenSource.Token()")
if err != nil {
return nil, fmt.Errorf("failed to refresh token from source: %w", err)
err = fmt.Errorf("failed to refresh token from source: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
if token.AccessToken != oidcTokenCache.Token.AccessToken {
span.AddEvent("updating cache with latest token")
oidcTokenCache = NewOidcTokenCache(oidcTokenCache.RedirectURL, token)
oidcTokenCacheJSON, err = json.Marshal(oidcTokenCache)
if err != nil {
return nil, fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
err = fmt.Errorf("failed to marshal oidc oidcTokenCache refresher: %w", err)
span.SetStatus(codes.Error, err.Error())
return nil, err
}
err = a.SetValueInEncryptedCache(cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
err = a.SetValueInEncryptedCache(ctx, cacheKey, oidcTokenCacheJSON, time.Until(token.Expiry))
if err != nil {
span.SetStatus(codes.Error, err.Error())
return nil, err
}
}
Expand Down Expand Up @@ -827,6 +901,9 @@ func (a *ClientApp) SetGroupsFromUserInfo(ctx context.Context, claims jwt.Claims

// GetUserInfo queries the IDP userinfo endpoint for claims
func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims, issuerURL, userInfoPath string) (jwt.MapClaims, bool, error) {
var span trace.Span
ctx, span = tracer.Start(ctx, "oidc.ClientApp.GetUserInfo")
defer span.End()
sub := jwtutil.StringField(actualClaims, "sub")
var claims jwt.MapClaims
var encClaims []byte
Expand All @@ -848,7 +925,7 @@ func (a *ClientApp) GetUserInfo(ctx context.Context, actualClaims jwt.MapClaims,
}

// check if the accessToken for the user is still present
accessTokenBytes, err := a.GetValueFromEncryptedCache(FormatAccessTokenCacheKey(sub))
accessTokenBytes, err := a.GetValueFromEncryptedCache(ctx, FormatAccessTokenCacheKey(sub))
if err != nil {
return claims, true, fmt.Errorf("could not read accessToken from cache for %s: %w", sub, err)
}
Expand Down
4 changes: 2 additions & 2 deletions util/oidc/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1418,7 +1418,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL),
if tt.insertIntoCache {
oidcTokenCacheJSON, err := json.Marshal(tt.oidcTokenCache)
require.NoError(t, err)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(tt.subject, tt.session), oidcTokenCacheJSON, time.Minute))
}
token, err := app.GetUpdatedOidcTokenFromCache(t.Context(), tt.subject, tt.session)
if tt.expectErrorContains != "" {
Expand Down Expand Up @@ -1509,7 +1509,7 @@ requestedScopes: ["oidc"]`, oidcTestServer.URL, tt.refreshTokenThreshold),
require.NotEmpty(t, sub)
sid := jwtutil.StringField(tt.groupClaims, "sid")
require.NotEmpty(t, sid)
require.NoError(t, app.SetValueInEncryptedCache(formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
require.NoError(t, app.SetValueInEncryptedCache(t.Context(), formatOidcTokenCacheKey(sub, sid), oidcTokenCacheJSON, time.Minute))
token, err := app.CheckAndRefreshToken(t.Context(), tt.groupClaims, cdSettings.RefreshTokenThreshold())
if tt.expectErrorContains != "" {
require.ErrorContains(t, err, tt.expectErrorContains)
Expand Down
Loading
Loading