Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 49 additions & 14 deletions service/internal/auth/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"google.golang.org/grpc/metadata"

sdkAudit "github.com/opentdf/platform/sdk/audit"
"github.com/opentdf/platform/service/logger"
Expand Down Expand Up @@ -71,6 +70,9 @@
ActionDelete = "delete"
ActionUnsafe = "unsafe"
ActionOther = "other"

mdAccessTokenKey = "access_token"

Check failure on line 74 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

const mdAccessTokenKey is unused (unused)
mdClientIDKey = "client_id"

Check failure on line 75 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

const mdClientIDKey is unused (unused)
)

// Authentication holds a jwks cache and information about the openid configuration
Expand Down Expand Up @@ -242,12 +244,16 @@
return
}

md, ok := metadata.FromIncomingContext(ctxWithJWK)
if !ok {
md = metadata.New(nil)
var clientID string
clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim
if clientIDClaim != "" {
if id, ok := accessTok.Get(clientIDClaim); ok {
if clientIDClaimValue, ok := id.(string); ok {

Check failure on line 251 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

shadow: declaration of "ok" shadows declaration at line 250 (govet)
clientID = clientIDClaimValue
}
}
}
md.Append("access_token", ctxAuth.GetRawAccessTokenFromContext(ctxWithJWK, nil))
ctxWithJWK = metadata.NewIncomingContext(ctxWithJWK, md)
ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID)

// Check if the token is allowed to access the resource
var action string
Expand All @@ -266,6 +272,8 @@
a.logger.WarnContext(r.Context(),
"permission denied",
slog.String("azp", accessTok.Subject()),
slog.String("configured_client_id_claim_name", clientIDClaim),
slog.String("client_id", clientID),
slog.Any("error", err),
)
http.Error(w, "permission denied", http.StatusForbidden)
Expand All @@ -274,12 +282,18 @@
http.Error(w, "internal server error", http.StatusInternalServerError)
return
} else if !allow {
a.logger.WarnContext(r.Context(), "permission denied", slog.String("azp", accessTok.Subject()))
a.logger.WarnContext(
r.Context(),
"permission denied",
slog.String("azp", accessTok.Subject()),
slog.String("configured_client_id_claim_name", clientIDClaim),
slog.String("client_id", clientID),
)
http.Error(w, "permission denied", http.StatusForbidden)
return
}

r = r.WithContext(ctxWithJWK)
r = r.WithContext(ctxWithMetadata)
handler.ServeHTTP(w, r)
})
}
Expand Down Expand Up @@ -319,7 +333,7 @@
resource := p[1] + "/" + p[2]
action := getAction(p[2])

token, newCtx, err := a.checkToken(
token, ctxWithJWK, err := a.checkToken(
ctx,
header,
ri,
Expand All @@ -329,11 +343,24 @@
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("unauthenticated"))
}

var clientID string
clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim
if clientIDClaim != "" {
if id, ok := token.Get(clientIDClaim); ok {
if idStr, ok := id.(string); ok {

Check failure on line 350 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

shadow: declaration of "ok" shadows declaration at line 349 (govet)
clientID = idStr
}
}
}
ctxWithMetadata := ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID)

// Check if the token is allowed to access the resource
if allowed, err := a.enforcer.Enforce(token, resource, action); err != nil {
if err.Error() == "permission denied" {
a.logger.Warn("permission denied",
slog.String("azp", token.Subject()),
slog.String("configured_client_id_claim_name", clientIDClaim),
slog.String("client_id", clientID),
slog.Any("error", err),
)
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
Expand All @@ -344,7 +371,7 @@
return nil, connect.NewError(connect.CodePermissionDenied, errors.New("permission denied"))
}

return next(newCtx, req)
return next(ctxWithMetadata, req)
})
}
return connect.UnaryInterceptorFunc(interceptor)
Expand Down Expand Up @@ -431,12 +458,12 @@
ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw)
return accessToken, ctx, nil
}
key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
dpopKey, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
if err != nil {
a.logger.Warn("failed to validate dpop", slog.Any("err", err))
return nil, nil, err
}
ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw)
ctx = ctxAuth.ContextWithAuthNInfo(ctx, dpopKey, accessToken, tokenRaw)
return accessToken, ctx, nil
}

Expand Down Expand Up @@ -657,7 +684,7 @@
func (a Authentication) ipcReauthCheck(ctx context.Context, path string, header http.Header) (context.Context, error) {
for _, route := range a.ipcReauthRoutes {
reqPath := path
if reqPath == route {

Check failure on line 687 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

`if reqPath == route` has complex nested blocks (complexity: 8) (nestif)
// Extract the token from the request
authHeader := header["Authorization"]
if len(authHeader) < 1 {
Expand All @@ -668,7 +695,7 @@
u = append(u, a.lookupGatewayPaths(ctx, path, header)...)

// Validate the token and create a JWT token
_, nextCtx, err := a.checkToken(ctx, authHeader, receiverInfo{
token, ctxWithJWK, err := a.checkToken(ctx, authHeader, receiverInfo{
u: u,
m: []string{http.MethodPost},
}, header["Dpop"])
Expand All @@ -677,7 +704,15 @@
}

// Return the next context with the token
return nextCtx, nil
var clientID string
if clientIDClaim := a.oidcConfiguration.Policy.ClientIDClaim; clientIDClaim != "" {
if id, ok := token.Get(clientIDClaim); ok {
if idStr, ok := id.(string); ok {

Check failure on line 710 in service/internal/auth/authn.go

View workflow job for this annotation

GitHub Actions / go (service)

shadow: declaration of "ok" shadows declaration at line 709 (govet)
clientID = idStr
}
}
}
return ctxAuth.ContextWithAuthnMetadata(ctxWithJWK, clientID), nil
}
}
return ctx, nil
Expand Down
87 changes: 85 additions & 2 deletions service/internal/auth/authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/wrapperspb"
)
Expand All @@ -56,6 +57,7 @@
}

type FakeAccessServiceServer struct {
clientID string
accessToken []string
dpopKey jwk.Key
kas.UnimplementedAccessServiceServer
Expand All @@ -72,6 +74,7 @@
func (f *FakeAccessServiceServer) Rewrap(ctx context.Context, req *connect.Request[kas.RewrapRequest]) (*connect.Response[kas.RewrapResponse], error) {
f.accessToken = req.Header()["Authorization"]
f.dpopKey = ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger())
f.clientID, _ = ctxAuth.GetClientIDFromContext(ctx)

return &connect.Response[kas.RewrapResponse]{Msg: &kas.RewrapResponse{}}, nil
}
Expand Down Expand Up @@ -148,7 +151,9 @@
}
}))

policyCfg := PolicyConfig{}
policyCfg := PolicyConfig{
ClientIDClaim: "cid",
}
err = defaults.Set(&policyCfg)
s.Require().NoError(err)

Expand Down Expand Up @@ -214,6 +219,8 @@
func (s *AuthSuite) Test_IPCUnaryServerInterceptor() {
// Mock the checkToken method to return a valid token and context
mockToken := jwt.New()
mockToken.Set("cid", "mockClientID")

Check failure on line 222 in service/internal/auth/authn_test.go

View workflow job for this annotation

GitHub Actions / go (service)

Error return value of `mockToken.Set` is not checked (errcheck)

type contextKey string
mockCtx := context.WithValue(context.Background(), contextKey("mockKey"), "mockValue")
s.auth._testCheckTokenFunc = func(_ context.Context, authHeader []string, _ receiverInfo, _ []string) (jwt.Token, context.Context, error) {
Expand All @@ -234,6 +241,9 @@
s.Require().NoError(err)
s.Require().NotNil(nextCtx)
s.Equal("mockValue", nextCtx.Value(contextKey("mockKey")))
clientID, err := ctxAuth.GetClientIDFromContext(nextCtx)
s.Require().NoError(err)
s.Equal("mockClientID", clientID)

// Test with a route not requiring reauthorization
nextCtx, err = s.auth.ipcReauthCheck(context.Background(), "/kas.AccessService/PublicKey", nil)
Expand All @@ -254,6 +264,63 @@
s.Contains(err.Error(), "unauthenticated")
}

func (s *AuthSuite) Test_ConnectUnaryServerInterceptor_ClientIDPropagated() {
tok := jwt.New()
s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)))
s.Require().NoError(tok.Set("iss", s.server.URL))
s.Require().NoError(tok.Set("aud", "test"))
// default client ID claim in policy config is 'azp'
s.Require().NoError(tok.Set("azp", "test-client-id"))
s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}}))

policyCfg := new(PolicyConfig)
err := defaults.Set(policyCfg)
s.Require().NoError(err)

authnConfig := AuthNConfig{
Issuer: s.server.URL,
Audience: "test",
Policy: *policyCfg,
}
config := Config{
AuthNConfig: authnConfig,
}
auth, err := NewAuthenticator(context.Background(), config, &logger.Logger{
Logger: slog.New(slog.Default().Handler()),
}, func(_ string, _ any) error { return nil })
s.Require().NoError(err)

// Sign the token
signedTok, err := jwt.Sign(tok, jwt.WithKey(jwa.RS256, s.key))
s.Require().NoError(err)

// Create a minimal connect server setup to properly test the interceptor
// This is necessary because connect requests need proper procedure routing
interceptor := connect.WithInterceptors(auth.ConnectUnaryServerInterceptor())

fakeServer := &FakeAccessServiceServer{}
mux := http.NewServeMux()
path, handler := kasconnect.NewAccessServiceHandler(fakeServer, interceptor)
mux.Handle(path, handler)

server := memhttp.New(mux)
defer server.Close()

// Create a connect client that sends a Bearer token
conn, _ := grpc.NewClient("passthrough://bufconn", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return server.Listener.DialContext(ctx, "tcp", "http://localhost:8080")
}), grpc.WithTransportCredentials(insecure.NewCredentials()))

client := kas.NewAccessServiceClient(conn)

// Make the request
_, err = client.Rewrap(metadata.AppendToOutgoingContext(s.T().Context(), "authorization", "Bearer "+string(signedTok)), &kas.RewrapRequest{})
s.Require().NoError(err)

// Assert that the client ID was properly extracted and set in the context
s.Equal("test-client-id", fakeServer.clientID)
}

func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() {
tok := jwt.New()
s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)))
Expand Down Expand Up @@ -482,7 +549,7 @@
s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Now().Add(time.Hour)))
s.Require().NoError(tok.Set("iss", s.server.URL))
s.Require().NoError(tok.Set("aud", "test"))
s.Require().NoError(tok.Set("cid", "client2"))
s.Require().NoError(tok.Set("cid", "client-123"))
s.Require().NoError(tok.Set("realm_access", map[string][]string{"roles": {"opentdf-standard"}}))
thumbprint, err := dpopKey.Thumbprint(crypto.SHA256)
s.Require().NoError(err)
Expand Down Expand Up @@ -517,6 +584,10 @@

_, err = client.Rewrap(context.Background(), &kas.RewrapRequest{})
s.Require().NoError(err)

Check failure on line 587 in service/internal/auth/authn_test.go

View workflow job for this annotation

GitHub Actions / go (service)

File is not properly formatted (gofumpt)
// interceptor propagated clientID from the token at the configured claim
s.Equal(fakeServer.clientID, "client-123")

Check failure on line 589 in service/internal/auth/authn_test.go

View workflow job for this annotation

GitHub Actions / go (service)

expected-actual: need to reverse actual and expected values (testifylint)

s.NotNil(fakeServer.dpopKey)
dpopJWKFromRequest, ok := fakeServer.dpopKey.(jwk.RSAPublicKey)
s.True(ok)
Expand Down Expand Up @@ -552,12 +623,15 @@

jwkChan := make(chan jwk.Key, 1)
timeout := make(chan string, 1)
clientIDChan := make(chan string, 1)
go func() {
time.Sleep(5 * time.Second)
timeout <- ""
}()
server := httptest.NewServer(s.auth.MuxHandler(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) {
jwkChan <- ctxAuth.GetJWKFromContext(req.Context(), logger.CreateTestLogger())
cid, _ := ctxAuth.GetClientIDFromContext(req.Context())
clientIDChan <- cid
})))
defer server.Close()

Expand Down Expand Up @@ -585,6 +659,15 @@
case <-timeout:
s.Require().FailNow("timed out waiting for call to complete")
}
var clientID string
select {
case cid := <-clientIDChan:
clientID = cid
case <-timeout:
s.Require().FailNow("timed out waiting for call to complete")
}

s.Equal(clientID, "client2")

Check failure on line 670 in service/internal/auth/authn_test.go

View workflow job for this annotation

GitHub Actions / go (service)

expected-actual: need to reverse actual and expected values (testifylint)

s.NotNil(dpopKeyFromRequest)
dpopJWKFromRequest, ok := dpopKeyFromRequest.(jwk.RSAPublicKey)
Expand Down
2 changes: 2 additions & 0 deletions service/internal/auth/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ type PolicyConfig struct {
UserNameClaim string `mapstructure:"username_claim" json:"username_claim" default:"preferred_username"`
// Claim to use for group/role information
GroupsClaim string `mapstructure:"groups_claim" json:"groups_claim" default:"realm_access.roles"`
// Claim to use to reference idP clientID
ClientIDClaim string `mapstructure:"client_id_claim" json:"client_id_claim" default:"azp"`
// Deprecated: Use GroupClain instead
RoleClaim string `mapstructure:"claim" json:"claim" default:"realm_access.roles"`
// Deprecated: Use Casbin grouping statements g, <user/group>, <role>
Expand Down
Loading
Loading