Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/kube/proxy/kube_creds.go
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ func newDynamicKubeCreds(ctx context.Context, cfg dynamicCredsConfig) (*dynamicK
func (d *dynamicKubeCreds) getTLSConfig() *tls.Config {
d.RLock()
defer d.RUnlock()
return d.staticCreds.tlsConfig
return d.staticCreds.getTLSConfig()
}

func (d *dynamicKubeCreds) getTransportConfig() *transport.Config {
Expand Down
13 changes: 9 additions & 4 deletions lib/kube/proxy/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ type dialContextFunc func(context.Context, string, string) (net.Conn, error)
// The transport is cached in the forwarder so that it can be reused for future
// requests. If the transport is not cached, a new one is created and cached.
func (f *Forwarder) transportForRequestWithImpersonation(sess *clusterSession) (http.RoundTripper, *tls.Config, error) {
// If the session has a kube API credentials, it means that the next hop is
// a Kubernetes API server. In this case, we can use the provided credentials
// to dial the next hop directly and never cache the transport.
if sess.kubeAPICreds != nil {
// If agent is running in agent mode, get the transport from the configured cluster
// credentials.
return sess.kubeAPICreds.getTransport(), sess.kubeAPICreds.getTLSConfig(), nil
}

// If the cluster is remote, the key is the teleport cluster name.
// If the cluster is local, the key is the teleport cluster name and the kubernetes
// cluster name: <teleport-cluster-name>/<kubernetes-cluster-name>.
Expand All @@ -73,10 +82,6 @@ func (f *Forwarder) transportForRequestWithImpersonation(sess *clusterSession) (
if sess.teleportCluster.isRemote {
// If the cluster is remote, create a new transport for the remote cluster.
httpTransport, tlsConfig, err = f.newRemoteClusterTransport(sess.teleportCluster.name)
} else if sess.kubeAPICreds != nil {
// If agent is running in agent mode, get the transport from the configured cluster
// credentials.
httpTransport, tlsConfig = sess.kubeAPICreds.getTransport(), sess.kubeAPICreds.getTLSConfig()
} else if f.cfg.ReverseTunnelSrv != nil {
// If agent is running in proxy mode, create a new transport for the local cluster.
httpTransport, tlsConfig, err = f.newLocalClusterTransport(sess.kubeClusterName)
Expand Down
44 changes: 44 additions & 0 deletions lib/kube/proxy/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ package proxy

import (
"context"
"crypto/tls"
"fmt"
"net"
"testing"

"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"

Expand Down Expand Up @@ -137,3 +139,45 @@ func newKubeServerWithProxyIDs(t *testing.T, hostname, hostID string, proxyIds [
ks.Spec.ProxyIDs = proxyIds
return ks
}

func TestDirectTransportNotCached(t *testing.T) {
t.Parallel()

transportClients, err := utils.NewFnCache(utils.FnCacheConfig{
TTL: transportCacheTTL,
Clock: clockwork.NewFakeClock(),
})
require.NoError(t, err)

forwarder := &Forwarder{
ctx: context.Background(),
cachedTransport: transportClients,
}

kubeAPICreds := &dynamicKubeCreds{
staticCreds: &staticKubeCreds{
tlsConfig: &tls.Config{
ServerName: "localhost",
},
},
}

clusterSess := &clusterSession{
kubeAPICreds: kubeAPICreds,
authContext: authContext{
kubeClusterName: "b",
teleportCluster: teleportClusterClient{
name: "a",
},
},
}

_, tlsConfig, err := forwarder.transportForRequestWithImpersonation(clusterSess)
require.NoError(t, err)
require.Equal(t, "localhost", tlsConfig.ServerName)

kubeAPICreds.staticCreds.tlsConfig.ServerName = "example.com"
_, tlsConfig, err = forwarder.transportForRequestWithImpersonation(clusterSess)
require.NoError(t, err)
require.Equal(t, "example.com", tlsConfig.ServerName)
}