Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions api/utils/iterutils/iter.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,18 @@ func Map[In, Out any](f func(In) Out, seq iter.Seq[In]) iter.Seq[Out] {
}
}
}

// Filter returns an iterator over seq that only includes the values v for which
// f(v) is true.
//
// Copied from https://github.com/golang/go/issues/61898. We should switch to an
// official package once it is available.
func Filter[V any](f func(V) bool, seq iter.Seq[V]) iter.Seq[V] {
return func(yield func(V) bool) {
for v := range seq {
if f(v) && !yield(v) {
return
}
}
}
}
18 changes: 18 additions & 0 deletions api/utils/iterutils/iter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,21 @@ func ExampleMap() {
// HELLO WORLD
// FOO
}

func ExampleFilter() {
inputs := []string{
"a",
"bb",
"ccc",
"dddd",
}
isOddLen := func(s string) bool {
return len(s)%2 == 1
}
for filtered := range Filter(isOddLen, slices.Values(inputs)) {
fmt.Println(filtered)
}
// Output:
// a
// ccc
}
28 changes: 28 additions & 0 deletions lib/services/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ import (
"context"
"fmt"
"iter"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"strconv"
Expand All @@ -34,7 +36,9 @@ import (
"k8s.io/apimachinery/pkg/util/validation"
kyaml "k8s.io/apimachinery/pkg/util/yaml"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -310,3 +314,27 @@ func getClusterDomain() string {
}
return "cluster.local"
}

// RewriteHeadersAndApplyValueTraits rewrites the provided request's headers
// while applying value traits to them.
func RewriteHeadersAndApplyValueTraits(r *http.Request, rewrites iter.Seq[*types.Header], traits wrappers.Traits, log *slog.Logger) {
for header := range rewrites {
values, err := ApplyValueTraits(header.Value, traits)
if err != nil {
log.DebugContext(r.Context(), "Failed to apply traits",
"header_value", header.Value,
"error", err,
)
continue
}
r.Header.Del(header.Name)
for _, value := range values {
switch http.CanonicalHeaderKey(header.Name) {
case teleport.HostHeader:
r.Host = value
default:
r.Header.Add(header.Name, value)
}
}
}
}
27 changes: 27 additions & 0 deletions lib/services/app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,13 @@
package services

import (
"log/slog"
"net/http"
"net/http/httptest"
"slices"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -375,3 +380,25 @@ func TestInsecureSkipVerify(t *testing.T) {
require.Equal(t, tt.expected, result)
}
}

func TestRewriteHeadersAndApplyValueTraits(t *testing.T) {
r := httptest.NewRequest("GET", "/foo", nil)
r.Header.Set("x-no-rewrite", "no-rewrite")
rewrites := []*types.Header{
{Name: "host", Value: "1.2.3.4"},
{Name: "x-rewrite", Value: "{{external.rewrite}}"},
// Missing traits should log a debug message that this rewrite is skipped.
{Name: "x-bad-rewrite", Value: "{{external.bad_rewrite}}"},
}
traits := map[string][]string{
"rewrite": {"value1", "value2"},
}
RewriteHeadersAndApplyValueTraits(r, slices.Values(rewrites), traits, slog.Default())

assert.Equal(t, "1.2.3.4", r.Host)
wantHeaders := make(http.Header)
wantHeaders.Add("x-rewrite", "value1")
wantHeaders.Add("x-rewrite", "value2")
wantHeaders.Add("x-no-rewrite", "no-rewrite")
assert.Equal(t, wantHeaders, r.Header)
}
22 changes: 22 additions & 0 deletions lib/srv/app/common/header_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,14 @@
package common

import (
"context"
"iter"
"log/slog"
"net/http/httputil"
"slices"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/iterutils"
"github.com/gravitational/teleport/lib/httplib/reverseproxy"
)

Expand Down Expand Up @@ -55,3 +61,19 @@ func (hr *HeaderRewriter) Rewrite(req *httputil.ProxyRequest) {
req.Out.Header.Set(XForwardedSSL, sslOff)
}
}

// AppRewriteHeaders returns an iterator for app headers to rewrite. Reserved
// headers are skipped.
func AppRewriteHeaders(ctx context.Context, rewrite *types.Rewrite, log *slog.Logger) iter.Seq[*types.Header] {
var headers []*types.Header
if rewrite != nil {
headers = rewrite.Headers
}
return iterutils.Filter(func(header *types.Header) bool {
if IsReservedHeader(header.Name) {
log.DebugContext(ctx, "Not rewriting Teleport reserved header", "header_name", header.Name)
return false
}
return true
}, slices.Values(headers))
}
41 changes: 41 additions & 0 deletions lib/srv/app/common/header_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,20 @@
package common

import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net/http"
"net/http/httputil"
"net/url"
"slices"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/httplib/reverseproxy"
)

Expand Down Expand Up @@ -148,3 +153,39 @@ func TestHeaderRewriter(t *testing.T) {
})
}
}

func TestAppRewriteHeaders(t *testing.T) {
tests := []struct {
name string
rewrite *types.Rewrite
wantHeaders []*types.Header
}{
{
name: "no rewrite",
rewrite: nil,
wantHeaders: nil,
},
{
name: "reserved header is filtered",
rewrite: &types.Rewrite{
Headers: []*types.Header{
{Name: "test-key-1", Value: "test-value-1"},
{Name: "teleport-jwt-assertion", Value: "teleport-jwt-assertion-value"},
{Name: "test-key-2", Value: "test-value-2"},
{Name: "X-Real-Ip", Value: "1.2.3.4"},
},
},
wantHeaders: []*types.Header{
{Name: "test-key-1", Value: "test-value-1"},
{Name: "test-key-2", Value: "test-value-2"},
},
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
actualHeaders := AppRewriteHeaders(context.Background(), test.rewrite, slog.Default())
require.Equal(t, test.wantHeaders, slices.Collect(actualHeaders))
})
}
}
78 changes: 78 additions & 0 deletions lib/srv/app/common/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Teleport
* Copyright (C) 2025 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package common

import (
"context"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/tlsca"
)

// AppTokenGenerator defines an interface for generating JWT token for an
// application (by auth).
type AppTokenGenerator interface {
GenerateAppToken(context.Context, types.GenerateAppTokenRequest) (string, error)
}

// GenerateJWTAndTraits is helper that generates a JWT for an application and
// populates the user traits with the result JWT for templating.
func GenerateJWTAndTraits(
ctx context.Context,
identity *tlsca.Identity,
app types.Application,
generator AppTokenGenerator,
) (string, wrappers.Traits, error) {
rewrite := app.GetRewrite()
traits := identity.Traits
roles := identity.Groups
if rewrite != nil {
switch rewrite.JWTClaims {
case types.JWTClaimsRewriteNone:
traits = nil
roles = nil
case types.JWTClaimsRewriteRoles:
traits = nil
case types.JWTClaimsRewriteTraits:
roles = nil
case "", types.JWTClaimsRewriteRolesAndTraits:
}
}

// Request a JWT token that will be attached to all requests.
jwt, err := generator.GenerateAppToken(ctx, types.GenerateAppTokenRequest{
Username: identity.Username,
Roles: roles,
Traits: traits,
URI: app.GetURI(),
Expires: identity.Expires,
})
if err != nil {
return "", nil, trace.Wrap(err)
}
if traits == nil {
traits = make(wrappers.Traits)
}
traits[constants.TraitJWT] = []string{jwt}
return jwt, traits, trace.Wrap(err)
}
33 changes: 1 addition & 32 deletions lib/srv/app/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,8 @@ import (
"github.com/gravitational/trace"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/recorder"
"github.com/gravitational/teleport/lib/httplib/reverseproxy"
Expand Down Expand Up @@ -146,40 +144,11 @@ func (c *ConnectionsHandler) newSessionChunk(ctx context.Context, identity *tlsc
// withJWTTokenForwarder is a sessionOpt that creates a forwarder that attaches
// a generated JWT token to all requests.
func (c *ConnectionsHandler) withJWTTokenForwarder(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error {
rewrite := app.GetRewrite()
traits := identity.Traits
roles := identity.Groups
if rewrite != nil {
switch rewrite.JWTClaims {
case types.JWTClaimsRewriteNone:
traits = nil
roles = nil
case types.JWTClaimsRewriteRoles:
traits = nil
case types.JWTClaimsRewriteTraits:
roles = nil
case "", types.JWTClaimsRewriteRolesAndTraits:
}
}

// Request a JWT token that will be attached to all requests.
jwt, err := c.cfg.AuthClient.GenerateAppToken(ctx, types.GenerateAppTokenRequest{
Username: identity.Username,
Roles: roles,
Traits: traits,
URI: app.GetURI(),
Expires: identity.Expires,
})
jwt, traits, err := common.GenerateJWTAndTraits(ctx, identity, app, c.cfg.AuthClient)
if err != nil {
return trace.Wrap(err)
}

// Add JWT token to the traits so it can be used in headers templating.
if traits == nil {
traits = make(wrappers.Traits)
}
traits[constants.TraitJWT] = []string{jwt}

// Create a rewriting transport that will be used to forward requests.
transport, err := newTransport(c.closeContext,
&transportConfig{
Expand Down
Loading
Loading