Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions lib/kube/proxy/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package proxy

import (
"context"
"crypto/tls"
"fmt"
"net"
"net/url"
Expand All @@ -35,6 +36,7 @@ import (
_ "k8s.io/client-go/plugin/pkg/client/auth/azure"
_ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
"k8s.io/client-go/rest"
"k8s.io/client-go/transport"

"github.com/gravitational/teleport/api/types"
kubeutils "github.com/gravitational/teleport/lib/kube/utils"
Expand Down Expand Up @@ -166,13 +168,42 @@ func extractKubeCreds(ctx context.Context, cluster string, clientCfg *rest.Confi
return nil, trace.Wrap(err, "failed to generate transport config from kubeconfig: %v", err)
}

transport, err := newDirectTransports(tlsConfig, transportConfig)
if err != nil {
return nil, trace.Wrap(err, "failed to generate transport from kubeconfig: %v", err)
}

log.Debug("Initialized Kubernetes credentials")
return &staticKubeCreds{
tlsConfig: tlsConfig,
transportConfig: transportConfig,
targetAddr: targetAddr,
kubeClient: client,
clientRestCfg: clientCfg,
transport: transport,
}, nil
}

// newDirectTransports creates a new http.Transport that will be used to connect to the Kubernetes API server.
// It is a direct connection, not going through a proxy.
func newDirectTransports(tlsConfig *tls.Config, transportConfig *transport.Config) (httpTransport, error) {
h1Transport, err := wrapTransport(newH1Transport(tlsConfig, nil), transportConfig)
if err != nil {
return httpTransport{}, trace.Wrap(err)
}

h2HTTPTransport, err := newH2Transport(tlsConfig, nil)
if err != nil {
return httpTransport{}, trace.Wrap(err)
}
h2Transport, err := wrapTransport(h2HTTPTransport, transportConfig)
if err != nil {
return httpTransport{}, trace.Wrap(err)
}

return httpTransport{
h1Transport: h1Transport,
h2Transport: h2Transport,
}, nil
}

Expand Down
2 changes: 2 additions & 0 deletions lib/kube/proxy/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,11 @@ current-context: foo
require.Empty(t, cmp.Diff(got, tt.want,
cmp.AllowUnexported(staticKubeCreds{}),
cmp.AllowUnexported(kubeDetails{}),
cmp.AllowUnexported(httpTransport{}),
cmp.Comparer(func(a, b *transport.Config) bool { return (a == nil) == (b == nil) }),
cmp.Comparer(func(a, b *kubernetes.Clientset) bool { return (a == nil) == (b == nil) }),
cmp.Comparer(func(a, b *rest.Config) bool { return (a == nil) == (b == nil) }),
cmp.Comparer(func(a, b httpTransport) bool { return true }),
))
})
}
Expand Down
145 changes: 48 additions & 97 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ import (
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/net/http2"
corev1 "k8s.io/api/core/v1"
kubeerrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -168,6 +167,10 @@ type ForwarderConfig struct {
TracerProvider oteltrace.TracerProvider
// Tracer is used to start spans.
tracer oteltrace.Tracer
// ConnTLSConfig is the TLS client configuration to use when connecting to
// the upstream Teleport proxy or Kubernetes service when forwarding requests
// using the forward identity (i.e. proxy impersonating a user) method.
ConnTLSConfig *tls.Config
}

// CheckAndSetDefaults checks and sets default values
Expand Down Expand Up @@ -227,8 +230,13 @@ func (f *ForwarderConfig) CheckAndSetDefaults() error {

switch f.KubeServiceType {
case KubeService:
case ProxyService:
case LegacyProxyService:
case ProxyService, LegacyProxyService:
if f.ConnTLSConfig == nil {
return trace.BadParameter("missing parameter TLSConfig")
}
// Reset the ServerName to ensure that the proxy does not use the
// proxy's hostname as the SNI when connecting to the Kubernetes service.
f.ConnTLSConfig.ServerName = ""
default:
return trace.BadParameter("unknown value for KubeServiceType")
}
Expand Down Expand Up @@ -256,6 +264,15 @@ func NewForwarder(cfg ForwarderConfig) (*Forwarder, error) {
return nil, trace.Wrap(err)
}

// TODO (tigrato): remove this once we have a better way to handle
// deleting expired entried clusters and kube_servers entries.
// In the meantime, we need to make sure that the cache is cleaned
// from time to time.
transportClients, err := ttlmap.New(defaults.ClientCacheSize)
if err != nil {
return nil, trace.Wrap(err)
}

closeCtx, close := context.WithCancel(cfg.Context)

fwd := &Forwarder{
Expand All @@ -270,7 +287,8 @@ func NewForwarder(cfg ForwarderConfig) (*Forwarder, error) {
ReadBufferSize: 1024,
WriteBufferSize: 1024,
},
clusterDetails: make(map[string]*kubeDetails),
clusterDetails: make(map[string]*kubeDetails),
cachedTransport: transportClients,
}
router := httprouter.New()

Expand Down Expand Up @@ -359,6 +377,13 @@ type Forwarder struct {
// local kubernetes clusters. If the service type is Proxy, it will
// use the heartbeat clusters.
getKubernetesServersForKubeCluster getKubeServersByNameFunc

// cachedTransport is a cache of http.Transport objects used to
// connect to Teleport services.
// TODO(tigrato): Implement a cache eviction policy using watchers.
cachedTransport *ttlmap.TTLMap
// cachedTransportMu is a mutex used to protect the cachedTransport.
cachedTransportMu sync.Mutex
}

// getKubeServersByNameFunc is a function that returns a list of
Expand Down Expand Up @@ -1514,7 +1539,7 @@ func (f *Forwarder) execNonInteractive(ctx *authContext, w http.ResponseWriter,
return nil, trace.AccessDenied("insufficient permissions to launch non-interactive session")
}

eventPodMeta := request.eventPodMeta(request.context, sess.creds)
eventPodMeta := request.eventPodMeta(request.context, sess.kubeAPICreds)

sessionStart := f.cfg.Clock.Now().UTC()

Expand Down Expand Up @@ -1901,8 +1926,8 @@ func (f *Forwarder) setupForwardingHeaders(sess *clusterSession, req *http.Reque
// We only have a direct host to provide when using local creds.
// Otherwise, use kube-teleport-proxy-alpn.teleport.cluster.local to pass TLS handshake and leverage TLS Routing.
req.URL.Host = fmt.Sprintf("%s%s", constants.KubeTeleportProxyALPNPrefix, constants.APIDomain)
if sess.creds != nil {
req.URL.Host = sess.creds.getTargetAddr()
if sess.kubeAPICreds != nil {
req.URL.Host = sess.kubeAPICreds.getTargetAddr()
}

// add origin headers so the service consuming the request on the other site
Expand Down Expand Up @@ -2100,9 +2125,9 @@ func (f *Forwarder) getExecutor(ctx authContext, sess *clusterSession, req *http
originalHeaders: req.Header,
})
rt := http.RoundTripper(upgradeRoundTripper)
if sess.creds != nil {
if sess.kubeAPICreds != nil {
var err error
rt, err = sess.creds.wrapTransport(rt)
rt, err = sess.kubeAPICreds.wrapTransport(rt)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2122,9 +2147,9 @@ func (f *Forwarder) getDialer(ctx authContext, sess *clusterSession, req *http.R
originalHeaders: req.Header,
})
rt := http.RoundTripper(upgradeRoundTripper)
if sess.creds != nil {
if sess.kubeAPICreds != nil {
var err error
rt, err = sess.creds.wrapTransport(rt)
rt, err = sess.kubeAPICreds.wrapTransport(rt)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -2140,10 +2165,13 @@ func (f *Forwarder) getDialer(ctx authContext, sess *clusterSession, req *http.R
// x509 short lived credentials, forwarding proxies and other data
type clusterSession struct {
authContext
parent *Forwarder
creds kubeCreds
tlsConfig *tls.Config
forwarder *forward.Forwarder
parent *Forwarder
// kubeAPICreds are the credentials used to authenticate to the Kubernetes API server.
// It is non-nil if the kubernetes cluster is served by this teleport service,
// nil otherwise.
kubeAPICreds kubeCreds
tlsConfig *tls.Config
forwarder *forward.Forwarder
// noAuditEvents is true if this teleport service should leave audit event
// logging to another service.
noAuditEvents bool
Expand Down Expand Up @@ -2335,7 +2363,7 @@ func (f *Forwarder) newClusterSessionLocal(ctx authContext) (*clusterSession, er
return &clusterSession{
parent: f,
authContext: ctx,
creds: details.kubeCreds,
kubeAPICreds: details.kubeCreds,
kubeClusterEndpoints: []kubeClusterEndpoint{{addr: details.getTargetAddr()}},
tlsConfig: details.getTLSConfig().Clone(),
}, nil
Expand Down Expand Up @@ -2371,30 +2399,14 @@ func (f *Forwarder) newClusterSessionDirect(ctx context.Context, authCtx authCon
// The reason being is that streaming requests are going to be upgraded to SPDY, which is only
// supported coming from an HTTP1 request.
func (f *Forwarder) makeSessionForwarder(sess *clusterSession) (*forward.Forwarder, error) {
var err error
transport := f.newTransport(sess.DialWithContext, sess.tlsConfig)

if sess.upgradeToHTTP2 {
// Upgrade transport to h2 where HTTP_PROXY and HTTPS_PROXY
// envs are not take into account purposely.
if err := http2.ConfigureTransport(transport); err != nil {
return nil, trace.Wrap(err)
}
}

rt := http.RoundTripper(transport)
if sess.creds != nil {
rt, err = sess.creds.wrapTransport(rt)
if err != nil {
return nil, trace.Wrap(err)
}
transport, err := f.transportForRequest(sess)
if err != nil {
return nil, trace.Wrap(err)
}

rt = tracehttp.NewTransport(rt)

forwarder, err := forward.New(
forward.FlushInterval(100*time.Millisecond),
forward.RoundTripper(rt),
forward.RoundTripper(transport),
forward.WebsocketDial(sess.Dial),
forward.Logger(f.log),
forward.ErrorHandler(fwdutils.ErrorHandlerFunc(f.formatForwardResponseError)),
Expand All @@ -2406,25 +2418,6 @@ func (f *Forwarder) makeSessionForwarder(sess *clusterSession) (*forward.Forward
return forwarder, nil
}

// DialContextFunc is a context network dialer function that returns a network connection
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

func (f *Forwarder) newTransport(dial DialContextFunc, tlsConfig *tls.Config) *http.Transport {
return &http.Transport{
DialContext: dial,
TLSClientConfig: tlsConfig,
// Increase the size of the connection pool. This substantially improves the
// performance of Teleport under load as it reduces the number of TLS
// handshakes performed.
MaxIdleConns: defaults.HTTPMaxIdleConns,
MaxIdleConnsPerHost: defaults.HTTPMaxIdleConnsPerHost,
// IdleConnTimeout defines the maximum amount of time before idle connections
// are closed. Leaving this unset will lead to connections open forever and
// will cause memory leaks in a long running process.
IdleConnTimeout: defaults.HTTPIdleTimeout,
}
}

// getOrCreateRequestContext creates a new certificate request for a given context,
// if there is no active CSR request in progress, or returns an existing one.
// if the new context has been created, cancel function is returned as a
Expand All @@ -2447,47 +2440,6 @@ func (f *Forwarder) getOrCreateRequestContext(key string) (context.Context, cont
}
}

func (f *Forwarder) getOrRequestClientCreds(tracingCtx context.Context, authCtx authContext) (*tls.Config, error) {
c := f.getClientCreds(authCtx)
if c == nil {
return f.serializedRequestClientCreds(tracingCtx, authCtx)
}
return c, nil
}

func (f *Forwarder) getClientCreds(ctx authContext) *tls.Config {
f.mu.Lock()
defer f.mu.Unlock()
creds, ok := f.clientCredentials.Get(ctx.key())
if !ok {
return nil
}
c := creds.(*tls.Config)
if !validClientCreds(f.cfg.Clock, c) {
return nil
}
return c
}

func (f *Forwarder) saveClientCreds(ctx authContext, c *tls.Config) error {
f.mu.Lock()
defer f.mu.Unlock()
return f.clientCredentials.Set(ctx.key(), c, ctx.sessionTTL)
}

func validClientCreds(clock clockwork.Clock, c *tls.Config) bool {
if len(c.Certificates) == 0 || len(c.Certificates[0].Certificate) == 0 {
return false
}
crt, err := x509.ParseCertificate(c.Certificates[0].Certificate[0])
if err != nil {
return false
}
// Make sure that the returned cert will be valid for at least 1 more
// minute.
return clock.Now().Add(time.Minute).Before(crt.NotAfter)
}

func (f *Forwarder) serializedRequestClientCreds(tracingCtx context.Context, authContext authContext) (*tls.Config, error) {
ctx, cancel := f.getOrCreateRequestContext(authContext.key())
if cancel != nil {
Expand Down Expand Up @@ -2713,7 +2665,6 @@ func (f *Forwarder) listPods(authCtx *authContext, w http.ResponseWriter, req *h
}

f.emitAuditEvent(authCtx, req, sess, status)

return nil, nil
}

Expand Down
15 changes: 14 additions & 1 deletion lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,7 @@ func TestKubeFwdHTTPProxyEnv(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
f := newMockForwader(ctx, t)

authCtx := mockAuthCtx(ctx, t, "kube-cluster", false)

lockWatcher, err := services.NewLockWatcher(ctx, services.LockWatcherConfig{
Expand Down Expand Up @@ -1194,6 +1195,10 @@ func TestKubeFwdHTTPProxyEnv(t *testing.T) {
return rt
}

h2Transport, err := newH2Transport(&tls.Config{
InsecureSkipVerify: true,
}, nil)
require.NoError(t, err)
f.clusterDetails = map[string]*kubeDetails{
"local": {
kubeCreds: &staticKubeCreds{
Expand All @@ -1202,6 +1207,12 @@ func TestKubeFwdHTTPProxyEnv(t *testing.T) {
transportConfig: &transport.Config{
WrapTransport: checkTransportProxy,
},
transport: httpTransport{
h1Transport: newH1Transport(&tls.Config{
InsecureSkipVerify: true,
}, nil),
h2Transport: h2Transport,
},
},
},
}
Expand Down Expand Up @@ -1243,7 +1254,8 @@ func TestKubeFwdHTTPProxyEnv(t *testing.T) {
func newMockForwader(ctx context.Context, t *testing.T) *Forwarder {
clientCreds, err := ttlmap.New(defaults.ClientCacheSize)
require.NoError(t, err)

cachedTransport, err := ttlmap.New(defaults.ClientCacheSize)
require.NoError(t, err)
csrClient, err := newMockCSRClient()
require.NoError(t, err)

Expand All @@ -1262,6 +1274,7 @@ func newMockForwader(ctx context.Context, t *testing.T) *Forwarder {
clientCredentials: clientCreds,
activeRequests: make(map[string]context.Context),
ctx: ctx,
cachedTransport: cachedTransport,
}
}

Expand Down
Loading