Skip to content

Commit 10b81e5

Browse files
chore(auth): DPoP and public fixes (#651)
- expose legacy public key endpoint as public - DPoP `htu` should include origin part of url - clarify error messages for dpop
1 parent 19d9bfe commit 10b81e5

File tree

2 files changed

+72
-38
lines changed

2 files changed

+72
-38
lines changed

service/internal/auth/authn.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"fmt"
99
"log/slog"
1010
"net/http"
11+
"net/url"
1112
"path/filepath"
1213
"slices"
1314
"strings"
@@ -37,6 +38,7 @@ var (
3738
"/kas.AccessService/PublicKey",
3839
"/healthz",
3940
"/.well-known/opentdf-configuration",
41+
"/kas/kas_public_key",
4042
"/kas/v2/kas_public_key",
4143
}
4244
// only asymmetric algorithms and no 'none'
@@ -125,10 +127,21 @@ func NewAuthenticator(ctx context.Context, cfg Config) (*Authentication, error)
125127
return a, nil
126128
}
127129

128-
type dpopInfo struct {
129-
headers []string
130-
path string
131-
method string
130+
type receiverInfo struct {
131+
// The URI of the request
132+
u string
133+
// The HTTP method of the request
134+
m string
135+
}
136+
137+
func normalizeURL(o string, u *url.URL) string {
138+
// Currently this does not do a full normatlization
139+
ou, err := url.Parse(o)
140+
if err != nil {
141+
return u.String()
142+
}
143+
ou.Path = u.Path
144+
return ou.String()
132145
}
133146

134147
// verifyTokenHandler is a http handler that verifies the token
@@ -145,11 +158,11 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
145158
http.Error(w, "missing authorization header", http.StatusUnauthorized)
146159
return
147160
}
148-
tok, newCtx, err := a.checkToken(r.Context(), header, dpopInfo{
149-
headers: r.Header["Dpop"],
150-
path: r.URL.Path,
151-
method: r.Method,
152-
})
161+
origin := r.Header.Get("Origin")
162+
tok, newCtx, err := a.checkToken(r.Context(), header, receiverInfo{
163+
u: normalizeURL(origin, r.URL),
164+
m: r.Method,
165+
}, r.Header["Dpop"])
153166

154167
if err != nil {
155168
slog.WarnContext(r.Context(), "failed to validate token", slog.String("error", err.Error()))
@@ -227,11 +240,11 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
227240
token, newCtx, err := a.checkToken(
228241
ctx,
229242
header,
230-
dpopInfo{
231-
headers: md["dpop"],
232-
path: info.FullMethod,
233-
method: http.MethodPost,
243+
receiverInfo{
244+
u: info.FullMethod,
245+
m: http.MethodPost,
234246
},
247+
md["dpop"],
235248
)
236249
if err != nil {
237250
slog.Warn("failed to validate token", slog.String("error", err.Error()))
@@ -254,7 +267,7 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
254267
}
255268

256269
// checkToken is a helper function to verify the token.
257-
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo dpopInfo) (jwt.Token, context.Context, error) {
270+
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo receiverInfo, dpopHeader []string) (jwt.Token, context.Context, error) {
258271
var (
259272
tokenRaw string
260273
)
@@ -313,7 +326,7 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo
313326
// come from token introspection
314327
return accessToken, ctx, nil
315328
}
316-
key, err := validateDPoP(accessToken, tokenRaw, dpopInfo)
329+
key, err := validateDPoP(accessToken, tokenRaw, dpopInfo, dpopHeader)
317330
if err != nil {
318331
return nil, nil, err
319332
}
@@ -336,11 +349,11 @@ func GetJWKFromContext(ctx context.Context) jwk.Key {
336349
panic("got something that is not a jwk.Key from the JWK context")
337350
}
338351

339-
func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo) (jwk.Key, error) {
340-
if len(dpopInfo.headers) != 1 {
341-
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(dpopInfo.headers))
352+
func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo receiverInfo, headers []string) (jwk.Key, error) {
353+
if len(headers) != 1 {
354+
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(headers))
342355
}
343-
dpopHeader := dpopInfo.headers[0]
356+
dpopHeader := headers[0]
344357

345358
cnf, ok := accessToken.Get("cnf")
346359
if !ok {
@@ -401,8 +414,9 @@ func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo
401414
return nil, fmt.Errorf("couldn't compute thumbprint for key in `jwk` in DPoP JWT")
402415
}
403416

404-
if base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(thumbprint) != jkt {
405-
return nil, fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token")
417+
thumbprintStr := base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(thumbprint)
418+
if thumbprintStr != jkt {
419+
return nil, fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token; cnf.jkt=[%v], computed=[%v]", jkt, thumbprintStr)
406420
}
407421

408422
// at this point we have the right key because its thumbprint matches the `jkt` claim
@@ -428,17 +442,17 @@ func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo
428442
return nil, fmt.Errorf("`htm` claim missing in DPoP JWT")
429443
}
430444

431-
if htm != dpopInfo.method {
432-
return nil, fmt.Errorf("incorrect `htm` claim in DPoP JWT")
445+
if htm != dpopInfo.m {
446+
return nil, fmt.Errorf("incorrect `htm` claim in DPoP JWT; should match [%v]", dpopInfo.m)
433447
}
434448

435449
htu, ok := dpopToken.Get("htu")
436450
if !ok {
437451
return nil, fmt.Errorf("`htu` claim missing in DPoP JWT")
438452
}
439453

440-
if htu != dpopInfo.path {
441-
return nil, fmt.Errorf("incorrect `htu` claim in DPoP JWT")
454+
if htu != dpopInfo.u {
455+
return nil, fmt.Errorf("incorrect `htu` claim in DPoP JWT; should match %v", dpopInfo.u)
442456
}
443457

444458
ath, ok := dpopToken.Get("ath")

service/internal/auth/authn_test.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"net"
1515
"net/http"
1616
"net/http/httptest"
17+
"net/url"
1718
"slices"
1819
"testing"
1920
"time"
@@ -24,6 +25,7 @@ import (
2425
"github.com/lestrrat-go/jwx/v2/jwt"
2526
"github.com/opentdf/platform/protocol/go/kas"
2627
sdkauth "github.com/opentdf/platform/sdk/auth"
28+
"github.com/stretchr/testify/require"
2729
"github.com/stretchr/testify/suite"
2830
"google.golang.org/grpc"
2931
"google.golang.org/grpc/codes"
@@ -32,6 +34,7 @@ import (
3234
"google.golang.org/grpc/status"
3335
"google.golang.org/grpc/test/bufconn"
3436
"google.golang.org/protobuf/types/known/wrapperspb"
37+
"gotest.tools/v3/assert"
3538
)
3639

3740
type AuthSuite struct {
@@ -164,6 +167,23 @@ func TestAuthSuite(t *testing.T) {
164167
suite.Run(t, new(AuthSuite))
165168
}
166169

170+
func TestNormalizeUrl(t *testing.T) {
171+
for _, tt := range []struct {
172+
origin, path, out string
173+
}{
174+
{"http://localhost", "/", "http://localhost/"},
175+
{"https://localhost", "/somewhere", "https://localhost/somewhere"},
176+
{"http://localhost", "", "http://localhost"},
177+
} {
178+
t.Run(tt.origin+tt.path, func(t *testing.T) {
179+
u, err := url.Parse(tt.path)
180+
require.NoError(t, err)
181+
s := normalizeURL(tt.origin, u)
182+
assert.Equal(t, s, tt.out)
183+
})
184+
}
185+
}
186+
167187
func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() {
168188
tok := jwt.New()
169189
s.Require().NoError(tok.Set(jwt.ExpirationKey, time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)))
@@ -173,7 +193,7 @@ func (s *AuthSuite) Test_CheckToken_When_JWT_Expired_Expect_Error() {
173193
s.NotNil(signedTok)
174194
s.Require().NoError(err)
175195

176-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
196+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
177197
s.Require().Error(err)
178198
s.Equal("\"exp\" not satisfied", err.Error())
179199
}
@@ -197,7 +217,7 @@ func (s *AuthSuite) Test_UnaryServerInterceptor_When_Authorization_Header_Missin
197217
}
198218

199219
func (s *AuthSuite) Test_CheckToken_When_Authorization_Header_Invalid_Expect_Error() {
200-
_, _, err := s.auth.checkToken(context.Background(), []string{"BPOP "}, dpopInfo{})
220+
_, _, err := s.auth.checkToken(context.Background(), []string{"BPOP "}, receiverInfo{}, nil)
201221
s.Require().Error(err)
202222
s.Equal("not of type bearer or dpop", err.Error())
203223
}
@@ -211,7 +231,7 @@ func (s *AuthSuite) Test_CheckToken_When_Missing_Issuer_Expect_Error() {
211231
s.NotNil(signedTok)
212232
s.Require().NoError(err)
213233

214-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
234+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
215235
s.Require().Error(err)
216236
s.Equal("missing issuer", err.Error())
217237
}
@@ -226,7 +246,7 @@ func (s *AuthSuite) Test_CheckToken_When_Invalid_Issuer_Value_Expect_Error() {
226246
s.NotNil(signedTok)
227247
s.Require().NoError(err)
228248

229-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
249+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
230250
s.Require().Error(err)
231251
s.Equal("invalid issuer", err.Error())
232252
}
@@ -240,7 +260,7 @@ func (s *AuthSuite) Test_CheckToken_When_Audience_Missing_Expect_Error() {
240260
s.NotNil(signedTok)
241261
s.Require().NoError(err)
242262

243-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
263+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
244264
s.Require().Error(err)
245265
s.Equal("claim \"aud\" not found", err.Error())
246266
}
@@ -255,7 +275,7 @@ func (s *AuthSuite) Test_CheckToken_When_Audience_Invalid_Expect_Error() {
255275
s.NotNil(signedTok)
256276
s.Require().NoError(err)
257277

258-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
278+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
259279
s.Require().Error(err)
260280
s.Equal("\"aud\" not satisfied", err.Error())
261281
}
@@ -271,7 +291,7 @@ func (s *AuthSuite) Test_CheckToken_When_Valid_No_DPoP_Expect_Error() {
271291
s.NotNil(signedTok)
272292
s.Require().NoError(err)
273293

274-
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
294+
_, _, err = s.auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
275295
s.Require().Error(err)
276296
s.Require().Contains(err.Error(), "dpop")
277297
}
@@ -348,15 +368,15 @@ func (s *AuthSuite) TestInvalid_DPoP_Cases() {
348368
_, _, err = s.auth.checkToken(
349369
context.Background(),
350370
[]string{fmt.Sprintf("DPoP %s", string(testCase.accessToken))},
351-
dpopInfo{
352-
headers: []string{dpopToken},
353-
path: "/a/path",
354-
method: http.MethodPost,
371+
receiverInfo{
372+
u: "/a/path",
373+
m: http.MethodPost,
355374
},
375+
[]string{dpopToken},
356376
)
357377

358378
s.Require().Error(err)
359-
s.Equal(testCase.errorMessage, err.Error())
379+
s.Contains(err.Error(), testCase.errorMessage)
360380
}
361381
}
362382

@@ -567,7 +587,7 @@ func (s *AuthSuite) Test_Allowing_Auth_With_No_DPoP() {
567587
s.NotNil(signedTok)
568588
s.Require().NoError(err)
569589

570-
_, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, dpopInfo{})
590+
_, ctx, err := auth.checkToken(context.Background(), []string{fmt.Sprintf("Bearer %s", string(signedTok))}, receiverInfo{}, nil)
571591
s.Require().NoError(err)
572592
s.Require().Nil(GetJWKFromContext(ctx))
573593
}

0 commit comments

Comments
 (0)