diff --git a/api/utils/iterutils/iter.go b/api/utils/iterutils/iter.go index 3d0ddf495f43c..91d68bb6741a7 100644 --- a/api/utils/iterutils/iter.go +++ b/api/utils/iterutils/iter.go @@ -35,3 +35,18 @@ func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] { } } } + +// Filter returns an iterator over seq that only includes the values v for which +// f(v) is true. +// +// Copied from https://github.com/golang/go/issues/61898. We should switch to an +// official package once it is available. +func Filter[V any](f func(V) bool, seq iter.Seq[V]) iter.Seq[V] { + return func(yield func(V) bool) { + for v := range seq { + if f(v) && !yield(v) { + return + } + } + } +} diff --git a/api/utils/iterutils/iter_test.go b/api/utils/iterutils/iter_test.go index e6bf3e3feecd2..763fe63ceccbe 100644 --- a/api/utils/iterutils/iter_test.go +++ b/api/utils/iterutils/iter_test.go @@ -37,3 +37,21 @@ func ExampleMap() { // HELLO WORLD // FOO } + +func ExampleFilter() { + inputs := []string{ + "a", + "bb", + "ccc", + "dddd", + } + isOddLen := func(s string) bool { + return len(s)%2 == 1 + } + for filtered := range Filter(isOddLen, slices.Values(inputs)) { + fmt.Println(filtered) + } + // Output: + // a + // ccc +} diff --git a/lib/services/app.go b/lib/services/app.go index 096d136000b8c..e0d0986e73487 100644 --- a/lib/services/app.go +++ b/lib/services/app.go @@ -22,7 +22,9 @@ import ( "context" "fmt" "iter" + "log/slog" "net" + "net/http" "net/url" "os" "strconv" @@ -34,7 +36,9 @@ import ( "k8s.io/apimachinery/pkg/util/validation" kyaml "k8s.io/apimachinery/pkg/util/yaml" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/utils" ) @@ -310,3 +314,27 @@ func getClusterDomain() string { } return "cluster.local" } + +// RewriteHeadersAndApplyValueTraits rewrites the provided request's headers +// while applying value traits to them. +func RewriteHeadersAndApplyValueTraits(r *http.Request, rewrites iter.Seq[*types.Header], traits wrappers.Traits, log *slog.Logger) { + for header := range rewrites { + values, err := ApplyValueTraits(header.Value, traits) + if err != nil { + log.DebugContext(r.Context(), "Failed to apply traits", + "header_value", header.Value, + "error", err, + ) + continue + } + r.Header.Del(header.Name) + for _, value := range values { + switch http.CanonicalHeaderKey(header.Name) { + case teleport.HostHeader: + r.Host = value + default: + r.Header.Add(header.Name, value) + } + } + } +} diff --git a/lib/services/app_test.go b/lib/services/app_test.go index 2dadd3fdb6cf6..a9e7334866279 100644 --- a/lib/services/app_test.go +++ b/lib/services/app_test.go @@ -19,8 +19,13 @@ package services import ( + "log/slog" + "net/http" + "net/http/httptest" + "slices" "testing" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -375,3 +380,25 @@ func TestInsecureSkipVerify(t *testing.T) { require.Equal(t, tt.expected, result) } } + +func TestRewriteHeadersAndApplyValueTraits(t *testing.T) { + r := httptest.NewRequest("GET", "/foo", nil) + r.Header.Set("x-no-rewrite", "no-rewrite") + rewrites := []*types.Header{ + {Name: "host", Value: "1.2.3.4"}, + {Name: "x-rewrite", Value: "{{external.rewrite}}"}, + // Missing traits should log a debug message that this rewrite is skipped. + {Name: "x-bad-rewrite", Value: "{{external.bad_rewrite}}"}, + } + traits := map[string][]string{ + "rewrite": {"value1", "value2"}, + } + RewriteHeadersAndApplyValueTraits(r, slices.Values(rewrites), traits, slog.Default()) + + assert.Equal(t, "1.2.3.4", r.Host) + wantHeaders := make(http.Header) + wantHeaders.Add("x-rewrite", "value1") + wantHeaders.Add("x-rewrite", "value2") + wantHeaders.Add("x-no-rewrite", "no-rewrite") + assert.Equal(t, wantHeaders, r.Header) +} diff --git a/lib/srv/app/common/header_rewriter.go b/lib/srv/app/common/header_rewriter.go index 707e16ea505c3..de2251611de57 100644 --- a/lib/srv/app/common/header_rewriter.go +++ b/lib/srv/app/common/header_rewriter.go @@ -19,8 +19,14 @@ package common import ( + "context" + "iter" + "log/slog" "net/http/httputil" + "slices" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/utils/iterutils" "github.com/gravitational/teleport/lib/httplib/reverseproxy" ) @@ -55,3 +61,19 @@ func (hr *HeaderRewriter) Rewrite(req *httputil.ProxyRequest) { req.Out.Header.Set(XForwardedSSL, sslOff) } } + +// AppRewriteHeaders returns an iterator for app headers to rewrite. Reserved +// headers are skipped. +func AppRewriteHeaders(ctx context.Context, rewrite *types.Rewrite, log *slog.Logger) iter.Seq[*types.Header] { + var headers []*types.Header + if rewrite != nil { + headers = rewrite.Headers + } + return iterutils.Filter(func(header *types.Header) bool { + if IsReservedHeader(header.Name) { + log.DebugContext(ctx, "Not rewriting Teleport reserved header", "header_name", header.Name) + return false + } + return true + }, slices.Values(headers)) +} diff --git a/lib/srv/app/common/header_rewriter_test.go b/lib/srv/app/common/header_rewriter_test.go index ff349d52aaf91..ccb166dde1426 100644 --- a/lib/srv/app/common/header_rewriter_test.go +++ b/lib/srv/app/common/header_rewriter_test.go @@ -19,15 +19,20 @@ package common import ( + "context" "crypto/tls" "fmt" + "log/slog" "net/http" "net/http/httputil" "net/url" + "slices" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/httplib/reverseproxy" ) @@ -148,3 +153,39 @@ func TestHeaderRewriter(t *testing.T) { }) } } + +func TestAppRewriteHeaders(t *testing.T) { + tests := []struct { + name string + rewrite *types.Rewrite + wantHeaders []*types.Header + }{ + { + name: "no rewrite", + rewrite: nil, + wantHeaders: nil, + }, + { + name: "reserved header is filtered", + rewrite: &types.Rewrite{ + Headers: []*types.Header{ + {Name: "test-key-1", Value: "test-value-1"}, + {Name: "teleport-jwt-assertion", Value: "teleport-jwt-assertion-value"}, + {Name: "test-key-2", Value: "test-value-2"}, + {Name: "X-Real-Ip", Value: "1.2.3.4"}, + }, + }, + wantHeaders: []*types.Header{ + {Name: "test-key-1", Value: "test-value-1"}, + {Name: "test-key-2", Value: "test-value-2"}, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + actualHeaders := AppRewriteHeaders(context.Background(), test.rewrite, slog.Default()) + require.Equal(t, test.wantHeaders, slices.Collect(actualHeaders)) + }) + } +} diff --git a/lib/srv/app/common/jwt.go b/lib/srv/app/common/jwt.go new file mode 100644 index 0000000000000..77f9fdf0a4301 --- /dev/null +++ b/lib/srv/app/common/jwt.go @@ -0,0 +1,78 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package common + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/constants" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/api/types/wrappers" + "github.com/gravitational/teleport/lib/tlsca" +) + +// AppTokenGenerator defines an interface for generating JWT token for an +// application (by auth). +type AppTokenGenerator interface { + GenerateAppToken(context.Context, types.GenerateAppTokenRequest) (string, error) +} + +// GenerateJWTAndTraits is helper that generates a JWT for an application and +// populates the user traits with the result JWT for templating. +func GenerateJWTAndTraits( + ctx context.Context, + identity *tlsca.Identity, + app types.Application, + generator AppTokenGenerator, +) (string, wrappers.Traits, error) { + rewrite := app.GetRewrite() + traits := identity.Traits + roles := identity.Groups + if rewrite != nil { + switch rewrite.JWTClaims { + case types.JWTClaimsRewriteNone: + traits = nil + roles = nil + case types.JWTClaimsRewriteRoles: + traits = nil + case types.JWTClaimsRewriteTraits: + roles = nil + case "", types.JWTClaimsRewriteRolesAndTraits: + } + } + + // Request a JWT token that will be attached to all requests. + jwt, err := generator.GenerateAppToken(ctx, types.GenerateAppTokenRequest{ + Username: identity.Username, + Roles: roles, + Traits: traits, + URI: app.GetURI(), + Expires: identity.Expires, + }) + if err != nil { + return "", nil, trace.Wrap(err) + } + if traits == nil { + traits = make(wrappers.Traits) + } + traits[constants.TraitJWT] = []string{jwt} + return jwt, traits, trace.Wrap(err) +} diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index a9ab09c3c9e38..068cc6cdb7c19 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -30,10 +30,8 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/constants" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/recorder" "github.com/gravitational/teleport/lib/httplib/reverseproxy" @@ -146,40 +144,11 @@ func (c *ConnectionsHandler) newSessionChunk(ctx context.Context, identity *tlsc // withJWTTokenForwarder is a sessionOpt that creates a forwarder that attaches // a generated JWT token to all requests. func (c *ConnectionsHandler) withJWTTokenForwarder(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { - rewrite := app.GetRewrite() - traits := identity.Traits - roles := identity.Groups - if rewrite != nil { - switch rewrite.JWTClaims { - case types.JWTClaimsRewriteNone: - traits = nil - roles = nil - case types.JWTClaimsRewriteRoles: - traits = nil - case types.JWTClaimsRewriteTraits: - roles = nil - case "", types.JWTClaimsRewriteRolesAndTraits: - } - } - - // Request a JWT token that will be attached to all requests. - jwt, err := c.cfg.AuthClient.GenerateAppToken(ctx, types.GenerateAppTokenRequest{ - Username: identity.Username, - Roles: roles, - Traits: traits, - URI: app.GetURI(), - Expires: identity.Expires, - }) + jwt, traits, err := common.GenerateJWTAndTraits(ctx, identity, app, c.cfg.AuthClient) if err != nil { return trace.Wrap(err) } - // Add JWT token to the traits so it can be used in headers templating. - if traits == nil { - traits = make(wrappers.Traits) - } - traits[constants.TraitJWT] = []string{jwt} - // Create a rewriting transport that will be used to forward requests. transport, err := newTransport(c.closeContext, &transportConfig{ diff --git a/lib/srv/app/transport.go b/lib/srv/app/transport.go index c970e22ad8c22..a06d4cfe54817 100644 --- a/lib/srv/app/transport.go +++ b/lib/srv/app/transport.go @@ -198,45 +198,14 @@ func (t *transport) rewriteRequest(r *http.Request) error { r.URL.Scheme = t.uri.Scheme r.URL.Host = t.uri.Host + // Add in JWT headers. + r.Header.Set(teleport.AppJWTHeader, t.jwt) // Add headers from rewrite configuration. - rewriteHeaders(r, t.transportConfig) - + rewriteHeaders := common.AppRewriteHeaders(r.Context(), t.app.GetRewrite(), t.log) + services.RewriteHeadersAndApplyValueTraits(r, rewriteHeaders, t.traits, t.log) return nil } -// rewriteHeaders applies headers rewrites from the application configuration. -func rewriteHeaders(r *http.Request, c *transportConfig) { - // Add in JWT headers. - r.Header.Set(teleport.AppJWTHeader, c.jwt) - - if c.app.GetRewrite() == nil || len(c.app.GetRewrite().Headers) == 0 { - return - } - for _, header := range c.app.GetRewrite().Headers { - if common.IsReservedHeader(header.Name) { - c.log.DebugContext(r.Context(), "Not rewriting Teleport reserved header", "header_name", header.Name) - continue - } - values, err := services.ApplyValueTraits(header.Value, c.traits) - if err != nil { - c.log.DebugContext(r.Context(), "Failed to apply traits", - "header_value", header.Value, - "error", err, - ) - continue - } - r.Header.Del(header.Name) - for _, value := range values { - switch http.CanonicalHeaderKey(header.Name) { - case teleport.HostHeader: - r.Host = value - default: - r.Header.Add(header.Name, value) - } - } - } -} - // needsPathRedirect checks if the request should be redirected to a different path. // At the moment, the only time a redirect happens is if URI specified is not // "/" and the public address being requested is "/".