Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
5c19f24
feat(sdk): pass the dpop through
mkleene Mar 19, 2024
294016d
make this still work if the key is not passed
mkleene Mar 19, 2024
3633163
lint and encapsulate getting the key from the context
mkleene Mar 19, 2024
00dcd28
just use the context, no HTTP headers
mkleene Mar 19, 2024
2b86cfd
lint
mkleene Mar 19, 2024
d956bf1
lint
mkleene Mar 20, 2024
bd9c2f8
Merge remote-tracking branch 'origin/main' into pass-dpop-token-through
mkleene Mar 20, 2024
763b435
fix rewrap test
mkleene Mar 20, 2024
c01cd70
fix rewrap test
mkleene Mar 20, 2024
d499c80
use the right padding
mkleene Mar 20, 2024
761bf7a
bad merge?
mkleene Mar 20, 2024
bc35b2e
try using that keycloak
mkleene Mar 20, 2024
47b8a54
bump
mkleene Mar 20, 2024
c49294f
enable auth and provision protocol mapper
mkleene Mar 20, 2024
1733df0
allow the endpoints that do not require auth
mkleene Mar 20, 2024
c053c0f
Merge branch 'main' into pass-dpop-token-through
mkleene Mar 20, 2024
0a9095e
Update docker-compose.yaml
mkleene Mar 20, 2024
717e4d3
Merge branch 'main' into pass-dpop-token-through
mkleene Mar 21, 2024
e357157
Merge remote-tracking branch 'origin/main' into pass-dpop-token-through
mkleene Mar 21, 2024
ccaff46
fix tests
mkleene Mar 21, 2024
5b80b12
fix tests
mkleene Mar 21, 2024
ab0a709
fix import
mkleene Mar 21, 2024
a84ed60
sort out permissions
mkleene Mar 21, 2024
22312c8
allow unauthenticated rewraps
mkleene Mar 21, 2024
f45e635
lint
mkleene Mar 21, 2024
b1d30dd
make this work with auth disabled
mkleene Mar 22, 2024
03fdc15
Merge branch 'main' into pass-dpop-token-through
mkleene Mar 22, 2024
cb4079b
ignore this for now
mkleene Mar 25, 2024
bfa8088
Merge remote-tracking branch 'origin/main' into pass-dpop-token-through
mkleene Mar 25, 2024
c9246ba
use the right docker compose
mkleene Mar 25, 2024
792a534
add the sdk
mkleene Mar 25, 2024
a88dd11
use a constant instead
mkleene Mar 25, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/checks.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ jobs:
- run: go install github.com/fullstorydev/grpcurl/cmd/[email protected]
- run: grpcurl -plaintext localhost:8080 list
- run: grpcurl -plaintext localhost:8080 grpc.health.v1.Health.Check
- run: grpcurl -plaintext localhost:8080 kas.AccessService/PublicKey
- run: curl --show-error --fail --insecure localhost:8080/kas/v2/kas_public_key
- run: go run examples/main.go encrypt "Hello Virtru"
- run: go run examples/main.go decrypt sensitive.txt.tdf
Expand Down
12 changes: 12 additions & 0 deletions cmd/provisionKeyloak.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,18 @@ var (
"id.token.claim": "true",
},
},
{
Name: gocloak.StringP("dpop-mapper"),
Protocol: gocloak.StringP("openid-connect"),
ProtocolMapper: gocloak.StringP("virtru-oidc-protocolmapper"),
Config: &map[string]string{
"claim.name": "tdf_claims",
"client.dpop": "true",
"tdf_claims.enabled": "true",
"access.token.claim": "true",
"client.publickey": "X-VirtruPubKey",
},
},
}

// Create Roles
Expand Down
3 changes: 2 additions & 1 deletion example-opentdf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ services:
- "msExchMailboxSecurityDescriptor"
server:
auth:
enabled: false
enabled: true
audience: "http://localhost:8080"
issuer: http://localhost:8888/auth/realms/opentdf
clients:
- "opentdf"
- "opentdf-sdk"
policy:
## Default policy for all requests
default: #"role:readonly"
Expand Down
2 changes: 1 addition & 1 deletion examples/cmd/decrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func decrypt(cmd *cobra.Command, args []string) error {
// Create new client
client, err := sdk.New(cmd.Context().Value(RootConfigKey).(*ExampleConfig).PlatformEndpoint,
sdk.WithInsecureConn(),
sdk.WithClientCredentials("opentdf", "secret", nil),
sdk.WithClientCredentials("opentdf-sdk", "secret", nil),
sdk.WithTokenEndpoint("http://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token"),
)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion examples/cmd/encrypt.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func encrypt(cmd *cobra.Command, args []string) error {

client, err := sdk.New(cmd.Context().Value(RootConfigKey).(*ExampleConfig).PlatformEndpoint,
sdk.WithInsecureConn(),
sdk.WithClientCredentials("opentdf", "secret", nil),
sdk.WithClientCredentials("opentdf-sdk", "secret", nil),
sdk.WithTokenEndpoint("http://localhost:8888/auth/realms/opentdf/protocol/openid-connect/token"),
)
if err != nil {
Expand Down
113 changes: 70 additions & 43 deletions internal/auth/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,24 @@ import (
"google.golang.org/grpc/status"
)

const (
dpopJWKContextKey = authContextKey("dpop-jwk")
)

type authContextKey string

var (
// Set of allowed gRPC endpoints that do not require authentication
allowedGRPCEndpoints = [...]string{
"/grpc.health.v1.Health/Check",
"/wellknownconfiguration.WellKnownService/GetWellKnownConfiguration",
"/kas.AccessService/PublicKey",
}
// Set of allowed HTTP endpoints that do not require authentication
allowedHTTPEndpoints = [...]string{
"/healthz",
"/.well-known/opentdf-configuration",
"/kas/v2/kas_public_key",
}
// only asymmetric algorithms and no 'none'
allowedSignatureAlgorithms = map[jwa.SignatureAlgorithm]bool{
Expand Down Expand Up @@ -121,14 +129,14 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
handler.ServeHTTP(w, r)
return
}

// Verify the token
header := r.Header["Authorization"]
if len(header) < 1 {
http.Error(w, "missing authorization header", http.StatusUnauthorized)
return
}

tok, err := a.checkToken(r.Context(), header, dpopInfo{
tok, dpopKey, err := a.checkToken(r.Context(), header, dpopInfo{
headers: r.Header["Dpop"],
path: r.URL.Path,
method: r.Method,
Expand Down Expand Up @@ -166,7 +174,7 @@ func (a Authentication) MuxHandler(handler http.Handler) http.Handler {
return
}

handler.ServeHTTP(w, r)
handler.ServeHTTP(w, r.WithContext(ContextWithJWK(r.Context(), dpopKey)))
})
}

Expand Down Expand Up @@ -207,7 +215,7 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
action = "other"
}

token, err := a.checkToken(
token, dpopJWK, err := a.checkToken(
ctx,
header,
dpopInfo{
Expand All @@ -233,11 +241,11 @@ func (a Authentication) UnaryServerInterceptor(ctx context.Context, req any, inf
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
}

return handler(ctx, req)
return handler(ContextWithJWK(ctx, dpopJWK), req)
}

// checkToken is a helper function to verify the token.
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo dpopInfo) (jwt.Token, error) {
func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpopInfo dpopInfo) (jwt.Token, jwk.Key, error) {
var (
tokenRaw string
tokenType string
Expand All @@ -252,33 +260,33 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo
tokenType = "Bearer"
tokenRaw = strings.TrimPrefix(authHeader[0], "Bearer ")
default:
return nil, fmt.Errorf("not of type bearer or dpop")
return nil, nil, fmt.Errorf("not of type bearer or dpop")
}

// We have to get iss from the token first to verify the signature
unverifiedToken, err := jwt.Parse([]byte(tokenRaw), jwt.WithVerify(false))
if err != nil {
return nil, err
return nil, nil, err
}

// Get issuer from unverified token
issuer := unverifiedToken.Issuer()
if issuer == "" {
return nil, fmt.Errorf("missing issuer")
return nil, nil, fmt.Errorf("missing issuer")
}

// Get the openid configuration for the issuer
// Because we get an interface we need to cast it to a string
// and jwx expects it as a string so we should never hit this error if the token is valid
oidc, exists := a.oidcConfigurations[issuer]
if !exists {
return nil, fmt.Errorf("invalid issuer")
return nil, nil, fmt.Errorf("invalid issuer")
}

// Get key set from cache that matches the jwks_uri
keySet, err := a.cache.Get(ctx, oidc.JwksURI)
if err != nil {
return nil, fmt.Errorf("failed to get jwks from cache")
return nil, nil, fmt.Errorf("failed to get jwks from cache")
}

// Now we verify the token signature
Expand All @@ -291,85 +299,104 @@ func (a Authentication) checkToken(ctx context.Context, authHeader []string, dpo
)

if err != nil {
return nil, err
return nil, nil, err
}

if tokenType == "Bearer" {
if _, ok := accessToken.Get("cnf"); !ok {
return accessToken, nil
}
slog.Info("presented token with `cnf` claim as a bearer token. validating as DPoP")
slog.Warn("Presented bearer token. validating as DPoP")
}

key, err := validateDPoP(accessToken, tokenRaw, dpopInfo)
if err != nil {
return nil, nil, err
}

return accessToken, validateDPoP(accessToken, tokenRaw, dpopInfo)
return accessToken, *key, nil
}

func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo) error {
func ContextWithJWK(ctx context.Context, key jwk.Key) context.Context {
return context.WithValue(ctx, dpopJWKContextKey, key)
}

func GetJWKFromContext(ctx context.Context) jwk.Key {
key := ctx.Value(dpopJWKContextKey)
if key == nil {
return nil
}
if jwk, ok := key.(jwk.Key); ok {
return jwk
}

return nil
}

func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo) (*jwk.Key, error) {
if len(dpopInfo.headers) != 1 {
return fmt.Errorf("got %d dpop headers, should have 1", len(dpopInfo.headers))
return nil, fmt.Errorf("got %d dpop headers, should have 1", len(dpopInfo.headers))
}
dpopHeader := dpopInfo.headers[0]

cnf, ok := accessToken.Get("cnf")
if !ok {
return fmt.Errorf("missing `cnf` claim in access token")
return nil, fmt.Errorf("missing `cnf` claim in access token")
}

cnfDict, ok := cnf.(map[string]interface{})
if !ok {
return fmt.Errorf("got `cnf` in an invalid format")
return nil, fmt.Errorf("got `cnf` in an invalid format")
}

jktI, ok := cnfDict["jkt"]
if !ok {
return fmt.Errorf("missing `jkt` field in `cnf` claim. only thumbprint JWK confirmation is supported")
return nil, fmt.Errorf("missing `jkt` field in `cnf` claim. only thumbprint JWK confirmation is supported")
}

jkt, ok := jktI.(string)
if !ok {
return fmt.Errorf("invalid `jkt` field in `cnf` claim: %v. the value must be a JWK thumbprint", jkt)
return nil, fmt.Errorf("invalid `jkt` field in `cnf` claim: %v. the value must be a JWK thumbprint", jkt)
}

dpop, err := jws.Parse([]byte(dpopHeader))
if err != nil {
slog.Error("error parsing JWT: %w", err)
return fmt.Errorf("invalid DPoP JWT")
return nil, fmt.Errorf("invalid DPoP JWT")
}
if len(dpop.Signatures()) != 1 {
return fmt.Errorf("expected one signature on DPoP JWT, got %d", len(dpop.Signatures()))
return nil, fmt.Errorf("expected one signature on DPoP JWT, got %d", len(dpop.Signatures()))
}
sig := dpop.Signatures()[0]
protectedHeaders := sig.ProtectedHeaders()
if protectedHeaders.Type() != "dpop+jwt" {
return fmt.Errorf("invalid typ on DPoP JWT: %v", protectedHeaders.Type())
return nil, fmt.Errorf("invalid typ on DPoP JWT: %v", protectedHeaders.Type())
}

if _, exists := allowedSignatureAlgorithms[protectedHeaders.Algorithm()]; !exists {
return fmt.Errorf("unsupported algorithm specified: %v", protectedHeaders.Algorithm())
return nil, fmt.Errorf("unsupported algorithm specified: %v", protectedHeaders.Algorithm())
}

dpopKey := protectedHeaders.JWK()
if dpopKey == nil {
return fmt.Errorf("JWK missing in DPoP JWT")
return nil, fmt.Errorf("JWK missing in DPoP JWT")
}

isPrivate, err := jwk.IsPrivateKey(dpopKey)
if err != nil {
slog.Error("error checking if key is private", err)
return fmt.Errorf("invalid DPoP key specified")
return nil, fmt.Errorf("invalid DPoP key specified")
}

if isPrivate {
return fmt.Errorf("cannot use a private key for DPoP")
return nil, fmt.Errorf("cannot use a private key for DPoP")
}

thumbprint, err := dpopKey.Thumbprint(crypto.SHA256)
if err != nil {
slog.Error("error computing thumbprint for key", err)
return fmt.Errorf("couldn't compute thumbprint for key in `jwk` in DPoP JWT")
return nil, fmt.Errorf("couldn't compute thumbprint for key in `jwk` in DPoP JWT")
}

if base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(thumbprint) != jkt {
return fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token")
return nil, fmt.Errorf("the `jkt` from the DPoP JWT didn't match the thumbprint from the access token")
}

// at this point we have the right key because its thumbprint matches the `jkt` claim
Expand All @@ -378,47 +405,47 @@ func validateDPoP(accessToken jwt.Token, acessTokenRaw string, dpopInfo dpopInfo

if err != nil {
slog.Error("error validating DPoP JWT", err)
return fmt.Errorf("failed to verify signature on DPoP JWT")
return nil, fmt.Errorf("failed to verify signature on DPoP JWT")
}

issuedAt := dpopToken.IssuedAt()
if issuedAt.IsZero() {
return fmt.Errorf("missing `iat` claim in the DPoP JWT")
return nil, fmt.Errorf("missing `iat` claim in the DPoP JWT")
}

if issuedAt.Add(time.Hour).Before(time.Now()) {
return fmt.Errorf("the DPoP JWT has expired")
return nil, fmt.Errorf("the DPoP JWT has expired")
}

htm, ok := dpopToken.Get("htm")
if !ok {
return fmt.Errorf("`htm` claim missing in DPoP JWT")
return nil, fmt.Errorf("`htm` claim missing in DPoP JWT")
}

if htm != dpopInfo.method {
return fmt.Errorf("incorrect `htm` claim in DPoP JWT")
return nil, fmt.Errorf("incorrect `htm` claim in DPoP JWT")
}

htu, ok := dpopToken.Get("htu")
if !ok {
return fmt.Errorf("`htu` claim missing in DPoP JWT")
return nil, fmt.Errorf("`htu` claim missing in DPoP JWT")
}

if htu != dpopInfo.path {
return fmt.Errorf("incorrect `htu` claim in DPoP JWT")
return nil, fmt.Errorf("incorrect `htu` claim in DPoP JWT")
}

ath, ok := dpopToken.Get("ath")
if !ok {
return fmt.Errorf("missing `ath` claim in DPoP JWT")
return nil, fmt.Errorf("missing `ath` claim in DPoP JWT")
}

h := sha256.New()
h.Write([]byte(acessTokenRaw))
if ath != base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(h.Sum(nil)) {
return fmt.Errorf("incorrect `ath` claim in DPoP JWT")
return nil, fmt.Errorf("incorrect `ath` claim in DPoP JWT")
}

return nil
return &dpopKey, nil
}

// claimsValidator is a custom validator to check extra claims in the token.
Expand Down
Loading