Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 39 additions & 25 deletions service/internal/auth/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"net/http"
"net/url"
"path/filepath"
"slices"
"strings"
Expand Down Expand Up @@ -37,6 +38,7 @@ var (
"/kas.AccessService/PublicKey",
"/healthz",
"/.well-known/opentdf-configuration",
"/kas/kas_public_key",
"/kas/v2/kas_public_key",
}
// only asymmetric algorithms and no 'none'
Expand Down Expand Up @@ -125,10 +127,21 @@ func NewAuthenticator(ctx context.Context, cfg Config) (*Authentication, error)
return a, nil
}

type dpopInfo struct {
headers []string
path string
method string
type receiverInfo struct {
// The URI of the request
u string
// The HTTP method of the request
m string
}

func normalizeURL(o string, u *url.URL) string {
// Currently this does not do a full normatlization
ou, err := url.Parse(o)
if err != nil {
return u.String()
}
ou.Path = u.Path
return ou.String()
}

// verifyTokenHandler is a http handler that verifies the token
Expand All @@ -145,11 +158,11 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
http.Error(w, "missing authorization header", http.StatusUnauthorized)
return
}
tok, newCtx, err := a.checkToken(r.Context(), header, dpopInfo{
headers: r.Header["Dpop"],
path: r.URL.Path,
method: r.Method,
})
origin := r.Header.Get("Origin")
tok, newCtx, err := a.checkToken(r.Context(), header, receiverInfo{
u: normalizeURL(origin, r.URL),
m: r.Method,
}, r.Header["Dpop"])

if err != nil {
slog.WarnContext(r.Context(), "failed to validate token", slog.String("error", err.Error()))
Expand Down Expand Up @@ -227,11 +240,11 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
token, newCtx, err := a.checkToken(
ctx,
header,
dpopInfo{
headers: md["dpop"],
path: info.FullMethod,
method: http.MethodPost,
receiverInfo{
u: info.FullMethod,
m: http.MethodPost,
},
md["dpop"],
)
if err != nil {
slog.Warn("failed to validate token", slog.String("error", err.Error()))
Expand All @@ -254,7 +267,7 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
}

// checkToken is a helper function to verify the token.
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo dpopInfo) (jwt.Token, context.Context, error) {
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo receiverInfo, dpopHeader []string) (jwt.Token, context.Context, error) {
var (
tokenRaw string
)
Expand Down Expand Up @@ -313,7 +326,7 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo
// come from token introspection
return accessToken, ctx, nil
}
key, err := validateDPoP(accessToken, tokenRaw, dpopInfo)
key, err := validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
if err != nil {
return nil, nil, err
}
Expand All @@ -336,11 +349,11 @@ func GetJWKFromContext(ctx context.Context) jwk.Key {
panic("got something that is not a jwk.Key from the JWK context")
}

func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo) (jwk.Key, error) {
if len(dpopInfo.headers) != 1 {
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(dpopInfo.headers))
func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo receiverInfo, headers []string) (jwk.Key, error) {
if len(headers) != 1 {
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(headers))
}
dpopHeader := dpopInfo.headers[0]
dpopHeader := headers[0]

cnf, ok := accessToken.Get("cnf")
if !ok {
Expand Down Expand Up @@ -401,8 +414,9 @@ func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo
return nil, fmt.Errorf("couldn't compute thumbprint for key in `jwk` in DPoP JWT")
}

if base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(thumbprint) != jkt {
return nil, fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token")
thumbprintStr := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(thumbprint)
if thumbprintStr != jkt {
return nil, fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token; cnf.jkt=[%v], computed=[%v]", jkt, thumbprintStr)
}

// at this point we have the right key because its thumbprint matches the `jkt` claim
Expand All @@ -428,17 +442,17 @@ func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo
return nil, fmt.Errorf("`htm` claim missing in DPoP JWT")
}

if htm != dpopInfo.method {
return nil, fmt.Errorf("incorrect `htm` claim in DPoP JWT")
if htm != dpopInfo.m {
return nil, fmt.Errorf("incorrect `htm` claim in DPoP JWT; should match [%v]", dpopInfo.m)
}

htu, ok := dpopToken.Get("htu")
if !ok {
return nil, fmt.Errorf("`htu` claim missing in DPoP JWT")
}

if htu != dpopInfo.path {
return nil, fmt.Errorf("incorrect `htu` claim in DPoP JWT")
if htu != dpopInfo.u {
return nil, fmt.Errorf("incorrect `htu` claim in DPoP JWT; should match %v", dpopInfo.u)
}

ath, ok := dpopToken.Get("ath")
Expand Down
46 changes: 33 additions & 13 deletions service/internal/auth/authn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"slices"
"testing"
"time"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/opentdf/platform/protocol/go/kas"
sdkauth "github.com/opentdf/platform/sdk/auth"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand All @@ -32,6 +34,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
"google.golang.org/protobuf/types/known/wrapperspb"
"gotest.tools/v3/assert"
)

type AuthSuite struct {
Expand Down Expand Up @@ -164,6 +167,23 @@ func TestAuthSuite(t *testing.T) {
suite.Run(t, new(AuthSuite))
}

func TestNormalizeUrl(t *testing.T) {
for _, tt := range []struct {
origin, path, out string
}{
{"http://localhost", "/", "http://localhost/"},
{"https://localhost", "/somewhere", "https://localhost/somewhere"},
{"http://localhost", "", "http://localhost"},
} {
t.Run(tt.origin+tt.path, func(t *testing.T) {
u, err := url.Parse(tt.path)
require.NoError(t, err)
s := normalizeURL(tt.origin, u)
assert.Equal(t, s, tt.out)
})
}
}

func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() {
tok := jwt.New()
s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)))
Expand All @@ -173,7 +193,7 @@ func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("\"exp\" not satisfied", err.Error())
}
Expand All @@ -197,7 +217,7 @@ func (s *AuthSuite) Test_UnaryServerInterceptor_When_Authorization_Header_Missin
}

func (s *AuthSuite) Test_CheckToken_When_Authorization_Header_Invalid_Expect_Error() {
_, _, err := s.auth.checkToken(context.Background(), []string{"BPOP "}, dpopInfo{})
_, _, err := s.auth.checkToken(context.Background(), []string{"BPOP "}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("not of type bearer or dpop", err.Error())
}
Expand All @@ -211,7 +231,7 @@ func (s *AuthSuite) Test_CheckToken_When_Missing_Issuer_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("missing issuer", err.Error())
}
Expand All @@ -226,7 +246,7 @@ func (s *AuthSuite) Test_CheckToken_When_Invalid_Issuer_Value_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("invalid issuer", err.Error())
}
Expand All @@ -240,7 +260,7 @@ func (s *AuthSuite) Test_CheckToken_When_Audience_Missing_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("claim \"aud\" not found", err.Error())
}
Expand All @@ -255,7 +275,7 @@ func (s *AuthSuite) Test_CheckToken_When_Audience_Invalid_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Equal("\"aud\" not satisfied", err.Error())
}
Expand All @@ -271,7 +291,7 @@ func (s *AuthSuite) Test_CheckToken_When_Valid_No_DPoP_Expect_Error() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().Error(err)
s.Require().Contains(err.Error(), "dpop")
}
Expand Down Expand Up @@ -348,15 +368,15 @@ func (s *AuthSuite) TestInvalid_DPoP_Cases() {
_, _, err = s.auth.checkToken(
context.Background(),
[]string{fmt.Sprintf("DPoP %s", string(testCase.accessToken))},
dpopInfo{
headers: []string{dpopToken},
path: "/a/path",
method: http.MethodPost,
receiverInfo{
u: "/a/path",
m: http.MethodPost,
},
[]string{dpopToken},
)

s.Require().Error(err)
s.Equal(testCase.errorMessage, err.Error())
s.Contains(err.Error(), testCase.errorMessage)
}
}

Expand Down Expand Up @@ -567,7 +587,7 @@ func (s *AuthSuite) Test_Allowing_Auth_With_No_DPoP() {
s.NotNil(signedTok)
s.Require().NoError(err)

_, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
_, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
s.Require().NoError(err)
s.Require().Nil(GetJWKFromContext(ctx))
}
Expand Down