diff --git a/service/internal/auth/authn.go b/service/internal/auth/authn.go index 7ad326f784..75a5f46aef 100644 --- a/service/internal/auth/authn.go +++ b/service/internal/auth/authn.go @@ -24,20 +24,10 @@ import ( sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/logger" -) -const ( - authnContextKey = authContextKey("dpop-jwk") + ctxAuth "github.com/opentdf/platform/service/pkg/auth" ) -type authContextKey string - -type authContext struct { - key jwk.Key - accessToken jwt.Token - rawToken string -} - var ( // Set of allowed public endpoints that do not require authentication allowedPublicEndpoints = [...]string{ @@ -394,7 +384,7 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo if !tokenHasCNF && !a.enforceDPoP { // this condition is not quite tight because it's possible that the `cnf` claim may // come from token introspection - ctx = ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, nil, accessToken, tokenRaw) return accessToken, ctx, nil } key, err := a.validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader) @@ -402,53 +392,10 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo a.logger.Warn("failed to validate dpop", slog.String("token", tokenRaw), slog.Any("err", err)) return nil, nil, err } - ctx = ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, accessToken, tokenRaw) return accessToken, ctx, nil } -func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context { - return context.WithValue(ctx, authnContextKey, &authContext{ - key, - accessToken, - raw, - }) -} - -func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { - key := ctx.Value(authnContextKey) - if key == nil { - return nil - } - if c, ok := key.(*authContext); ok { - return c - } - - // We should probably return an error here? - l.ErrorContext(ctx, "invalid authContext") - return nil -} - -func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key { - if c := getContextDetails(ctx, l); c != nil { - return c.key - } - return nil -} - -func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token { - if c := getContextDetails(ctx, l); c != nil { - return c.accessToken - } - return nil -} - -func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string { - if c := getContextDetails(ctx, l); c != nil { - return c.rawToken - } - return "" -} - func (a Authentication) validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo receiverInfo, headers []string) (jwk.Key, error) { if len(headers) != 1 { return nil, fmt.Errorf("got %d dpop headers, should have 1", len(headers)) diff --git a/service/internal/auth/authn_test.go b/service/internal/auth/authn_test.go index 9a7d86d663..b28d011f24 100644 --- a/service/internal/auth/authn_test.go +++ b/service/internal/auth/authn_test.go @@ -31,6 +31,7 @@ import ( sdkauth "github.com/opentdf/platform/sdk/auth" "github.com/opentdf/platform/service/internal/server/memhttp" "github.com/opentdf/platform/service/logger" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -69,7 +70,7 @@ func (f *FakeAccessServiceServer) LegacyPublicKey(_ context.Context, _ *connect. func (f *FakeAccessServiceServer) Rewrap(ctx context.Context, req *connect.Request[kas.RewrapRequest]) (*connect.Response[kas.RewrapResponse], error) { f.accessToken = req.Header()["Authorization"] - f.dpopKey = GetJWKFromContext(ctx, logger.CreateTestLogger()) + f.dpopKey = ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger()) return &connect.Response[kas.RewrapResponse]{Msg: &kas.RewrapResponse{}}, nil } @@ -512,7 +513,7 @@ func (s *AuthSuite) TestDPoPEndToEnd_HTTP() { timeout <- "" }() server := httptest.NewServer(s.auth.MuxHandler(http.HandlerFunc(func(_ http.ResponseWriter, req *http.Request) { - jwkChan <- GetJWKFromContext(req.Context(), logger.CreateTestLogger()) + jwkChan <- ctxAuth.GetJWKFromContext(req.Context(), logger.CreateTestLogger()) }))) defer server.Close() @@ -638,7 +639,7 @@ func (s *AuthSuite) Test_Allowing_Auth_With_No_DPoP() { _, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil) s.Require().NoError(err) - s.Require().Nil(GetJWKFromContext(ctx, logger.CreateTestLogger())) + s.Require().Nil(ctxAuth.GetJWKFromContext(ctx, logger.CreateTestLogger())) } func (s *AuthSuite) Test_PublicPath_Matches() { diff --git a/service/kas/access/rewrap.go b/service/kas/access/rewrap.go index 61dbcff432..ea4e538218 100644 --- a/service/kas/access/rewrap.go +++ b/service/kas/access/rewrap.go @@ -30,10 +30,10 @@ import ( kaspb "github.com/opentdf/platform/protocol/go/kas" "github.com/opentdf/platform/sdk" - "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/internal/security" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/logger/audit" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -128,7 +128,7 @@ func extractSRTBody(ctx context.Context, headers http.Header, in *kaspb.RewrapRe } // get dpop public key from context - dpopJWK := auth.GetJWKFromContext(ctx, &logger) + dpopJWK := ctxAuth.GetJWKFromContext(ctx, &logger) var err error var rbString string @@ -247,7 +247,7 @@ func verifyAndParsePolicy(ctx context.Context, requestBody *RequestBody, k []byt func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, error) { info := new(entityInfo) - token := auth.GetAccessTokenFromContext(ctx, logger) + token := ctxAuth.GetAccessTokenFromContext(ctx, logger) if token == nil { return nil, err401("missing access token") } @@ -263,7 +263,7 @@ func getEntityInfo(ctx context.Context, logger *logger.Logger) (*entityInfo, err logger.WarnContext(ctx, "missing sub") } - info.Token = auth.GetRawAccessTokenFromContext(ctx, logger) + info.Token = ctxAuth.GetRawAccessTokenFromContext(ctx, logger) return info, nil } diff --git a/service/kas/access/rewrap_test.go b/service/kas/access/rewrap_test.go index f75348a09d..a9eff6e975 100644 --- a/service/kas/access/rewrap_test.go +++ b/service/kas/access/rewrap_test.go @@ -17,8 +17,8 @@ import ( "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/opentdf/platform/lib/ocrypto" - "github.com/opentdf/platform/service/internal/auth" "github.com/opentdf/platform/service/logger" + ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -328,7 +328,7 @@ func TestParseAndVerifyRequest(t *testing.T) { require.NoError(t, err, "couldn't get JWK from key") err = key.Set(jwk.AlgorithmKey, jwa.RS256) // Check the error return value require.NoError(t, err, "failed to set algorithm key") - ctx = auth.ContextWithAuthNInfo(ctx, key, mockJWT(t), bearer) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, mockJWT(t), bearer) } md := metadata.New(map[string]string{"token": bearer}) @@ -370,7 +370,7 @@ func Test_SignedRequestBody_When_Bad_Signature_Expect_Failure(t *testing.T) { err = key.Set(jwk.AlgorithmKey, jwa.NoSignature) require.NoError(t, err, "failed to set algorithm key") - ctx = auth.ContextWithAuthNInfo(ctx, key, mockJWT(t), string(jwtStandard(t))) + ctx = ctxAuth.ContextWithAuthNInfo(ctx, key, mockJWT(t), string(jwtStandard(t))) md := metadata.New(map[string]string{"token": string(jwtWrongKey(t))}) ctx = metadata.NewIncomingContext(ctx, md) diff --git a/service/pkg/auth/context_auth.go b/service/pkg/auth/context_auth.go new file mode 100644 index 0000000000..9a1856282d --- /dev/null +++ b/service/pkg/auth/context_auth.go @@ -0,0 +1,64 @@ +package auth + +import ( + "context" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/opentdf/platform/service/logger" +) + +var ( + authnContextKey = authContextKey{} +) + +type authContextKey struct{} + +type authContext struct { + key jwk.Key + accessToken jwt.Token + rawToken string +} + +func ContextWithAuthNInfo(ctx context.Context, key jwk.Key, accessToken jwt.Token, raw string) context.Context { + return context.WithValue(ctx, authnContextKey, &authContext{ + key, + accessToken, + raw, + }) +} + +func getContextDetails(ctx context.Context, l *logger.Logger) *authContext { + key := ctx.Value(authnContextKey) + if key == nil { + return nil + } + if c, ok := key.(*authContext); ok { + return c + } + + // We should probably return an error here? + l.ErrorContext(ctx, "invalid authContext") + return nil +} + +func GetJWKFromContext(ctx context.Context, l *logger.Logger) jwk.Key { + if c := getContextDetails(ctx, l); c != nil { + return c.key + } + return nil +} + +func GetAccessTokenFromContext(ctx context.Context, l *logger.Logger) jwt.Token { + if c := getContextDetails(ctx, l); c != nil { + return c.accessToken + } + return nil +} + +func GetRawAccessTokenFromContext(ctx context.Context, l *logger.Logger) string { + if c := getContextDetails(ctx, l); c != nil { + return c.rawToken + } + return "" +} diff --git a/service/pkg/auth/context_auth_test.go b/service/pkg/auth/context_auth_test.go new file mode 100644 index 0000000000..eedb5b540f --- /dev/null +++ b/service/pkg/auth/context_auth_test.go @@ -0,0 +1,72 @@ +package auth + +import ( + "context" + "testing" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/opentdf/platform/service/logger" + "github.com/stretchr/testify/assert" +) + +func TestContextWithAuthNInfo(t *testing.T) { + // Create mock JWK, JWT, and raw token + mockJWK, _ := jwk.FromRaw([]byte("mockKey")) + mockJWT, _ := jwt.NewBuilder().Build() + rawToken := "mockRawToken" + + // Initialize context + ctx := context.Background() + newCtx := ContextWithAuthNInfo(ctx, mockJWK, mockJWT, rawToken) + + // Assert that the context contains the correct values + value := newCtx.Value(authnContextKey) + testAuthContext, ok := value.(*authContext) + assert.True(t, ok) + assert.NotNil(t, testAuthContext) + assert.Equal(t, mockJWK, testAuthContext.key, "JWK should match") + assert.Equal(t, mockJWT, testAuthContext.accessToken, "JWT should match") + assert.Equal(t, rawToken, testAuthContext.rawToken, "Raw token should match") +} + +func TestGetJWKFromContext(t *testing.T) { + // Create mock context with JWK + mockJWK, _ := jwk.FromRaw([]byte("mockKey")) + ctx := ContextWithAuthNInfo(context.Background(), mockJWK, nil, "") + + // Retrieve the JWK and assert + retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + assert.NotNil(t, retrievedJWK, "JWK should not be nil") + assert.Equal(t, mockJWK, retrievedJWK, "Retrieved JWK should match the mock JWK") +} + +func TestGetAccessTokenFromContext(t *testing.T) { + // Create mock context with JWT + mockJWT, _ := jwt.NewBuilder().Build() + ctx := ContextWithAuthNInfo(context.Background(), nil, mockJWT, "") + + // Retrieve the JWT and assert + retrievedJWT := GetAccessTokenFromContext(ctx, logger.CreateTestLogger()) + assert.NotNil(t, retrievedJWT, "Access token should not be nil") + assert.Equal(t, mockJWT, retrievedJWT, "Retrieved JWT should match the mock JWT") +} + +func TestGetRawAccessTokenFromContext(t *testing.T) { + // Create mock context with raw token + rawToken := "mockRawToken" + ctx := ContextWithAuthNInfo(context.Background(), nil, nil, rawToken) + + // Retrieve the raw token and assert + retrievedRawToken := GetRawAccessTokenFromContext(ctx, logger.CreateTestLogger()) + assert.Equal(t, rawToken, retrievedRawToken, "Retrieved raw token should match the mock raw token") +} + +func TestGetContextDetailsInvalidType(t *testing.T) { + // Create a context with an invalid type + ctx := context.WithValue(context.Background(), authnContextKey, "invalidType") + + // Assert that GetJWKFromContext handles the invalid type correctly + retrievedJWK := GetJWKFromContext(ctx, logger.CreateTestLogger()) + assert.Nil(t, retrievedJWK, "JWK should be nil when context value is invalid") +}