Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
mathrand "math/rand"
"net"
"net/http"
"net/url"
"path/filepath"
"regexp"
"strconv"
Expand All @@ -55,6 +56,7 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/util/httpstream"
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apiserver/pkg/util/wsstream"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
Expand Down Expand Up @@ -2164,6 +2166,7 @@ func (f *Forwarder) getExecutor(ctx authContext, sess *clusterSession, req *http
tlsConfig: sess.tlsConfig,
pingPeriod: f.cfg.ConnPingPeriod,
originalHeaders: req.Header,
proxier: sess.getProxier(),
})
rt := http.RoundTripper(upgradeRoundTripper)
if sess.kubeAPICreds != nil {
Expand All @@ -2186,6 +2189,7 @@ func (f *Forwarder) getDialer(ctx authContext, sess *clusterSession, req *http.R
tlsConfig: sess.tlsConfig,
pingPeriod: f.cfg.ConnPingPeriod,
originalHeaders: req.Header,
proxier: sess.getProxier(),
})
rt := http.RoundTripper(upgradeRoundTripper)
if sess.kubeAPICreds != nil {
Expand Down Expand Up @@ -2319,6 +2323,20 @@ func (s *clusterSession) dial(ctx context.Context, network string) (net.Conn, er
return nil, trace.NewAggregate(errs...)
}

// getProxier returns the proxier function to use for this session.
// If the target cluster is not served by this teleport service, the proxier
// must be nil to avoid using it through the reverse tunnel.
// If the target cluster is served by this teleport service, the proxier
// must be set to the default proxy function.
func (s *clusterSession) getProxier() func(req *http.Request) (*url.URL, error) {
// When the target cluster is not served by this teleport service, the
// proxier must be nil to avoid using it through the reverse tunnel.
if s.kubeAPICreds == nil {
return nil
}
return utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment)
}

// TODO(awly): unit test this
func (f *Forwarder) newClusterSession(ctx context.Context, authCtx authContext) (*clusterSession, error) {
ctx, span := f.cfg.tracer.Start(
Expand Down
91 changes: 70 additions & 21 deletions lib/kube/proxy/roundtrip.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import (
utilnet "k8s.io/apimachinery/pkg/util/net"
"k8s.io/apimachinery/third_party/forked/golang/netutil"

"github.com/gravitational/teleport/lib/utils"
apiclient "github.com/gravitational/teleport/api/client"
)

// SpdyRoundTripper knows how to upgrade an HTTP request to one that supports
Expand Down Expand Up @@ -71,6 +71,8 @@ type SpdyRoundTripper struct {
pingPeriod time.Duration
// originalHeaders are the headers that were passed from the original request.
originalHeaders http.Header

proxier func(*http.Request) (*url.URL, error)
}

var (
Expand All @@ -89,12 +91,13 @@ type roundTripperConfig struct {
tlsConfig *tls.Config
pingPeriod time.Duration
originalHeaders http.Header
proxier func(*http.Request) (*url.URL, error)
}

// NewSpdyRoundTripperWithDialer creates a new SpdyRoundTripper that will use
// the specified tlsConfig. This function is mostly meant for unit tests.
func NewSpdyRoundTripperWithDialer(cfg roundTripperConfig) *SpdyRoundTripper {
return &SpdyRoundTripper{tlsConfig: cfg.tlsConfig, dialWithContext: cfg.dial, ctx: cfg.ctx, authCtx: cfg.authCtx, pingPeriod: cfg.pingPeriod, originalHeaders: cfg.originalHeaders}
return &SpdyRoundTripper{tlsConfig: cfg.tlsConfig, dialWithContext: cfg.dial, ctx: cfg.ctx, authCtx: cfg.authCtx, pingPeriod: cfg.pingPeriod, originalHeaders: cfg.originalHeaders, proxier: cfg.proxier}
}

// TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during
Expand All @@ -105,7 +108,7 @@ func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config {

// Dial implements k8s.io/apimachinery/pkg/util/net.Dialer.
func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
conn, err := s.dial(req.URL)
conn, err := s.dial(req)
if err != nil {
return nil, err
}
Expand All @@ -119,37 +122,84 @@ func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) {
}

// dial dials the host specified by url, using TLS if appropriate.
func (s *SpdyRoundTripper) dial(url *url.URL) (net.Conn, error) {
dialAddr := netutil.CanonicalAddr(url)

if url.Scheme == "http" {
switch {
case s.dialWithContext != nil:
return s.dialWithContext(s.ctx, "tcp", dialAddr)
default:
return net.Dial("tcp", dialAddr)
func (s *SpdyRoundTripper) dial(req *http.Request) (conn net.Conn, err error) {
var proxyURL *url.URL
if s.proxier != nil {
proxyURL, err = s.proxier(req)
if err != nil {
return nil, err
}
}

// TODO validate the TLSClientConfig is set up?
var conn *tls.Conn
var err error
if s.dialWithContext == nil {
conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig)
if proxyURL == nil {
conn, err = s.dialWithoutProxy(req.URL)
} else {
conn, err = utils.TLSDial(s.ctx, s.dialWithContext, "tcp", dialAddr, s.tlsConfig)
conn, err = s.dialWithProxy(req, proxyURL)
}
if err != nil {
return nil, trace.Wrap(err)
}

if req.URL.Scheme == "https" {
return s.tlsConn(s.ctx, conn, netutil.CanonicalAddr(req.URL))
}
return conn, nil
}

func (s *SpdyRoundTripper) dialWithoutProxy(url *url.URL) (conn net.Conn, err error) {
dialAddr := netutil.CanonicalAddr(url)
switch {
case s.dialWithContext != nil:
conn, err = s.dialWithContext(s.ctx, "tcp", dialAddr)
default:
conn, err = net.Dial("tcp", dialAddr)
}
return conn, trace.Wrap(err)
}

// tlsConn returns a TLS client side connection using rwc as the underlying transport.
func (s *SpdyRoundTripper) tlsConn(ctx context.Context, rwc net.Conn, targetHost string) (net.Conn, error) {
host, _, err := net.SplitHostPort(targetHost)
if err != nil {
return nil, err
}

tlsConfig := s.tlsConfig
switch {
case tlsConfig == nil:
tlsConfig = &tls.Config{ServerName: host}
case len(tlsConfig.ServerName) == 0:
tlsConfig = tlsConfig.Clone()
tlsConfig.ServerName = host
}
tlsConn := tls.Client(rwc, tlsConfig)

// Client handshake will verify the server hostname and cert chain. That
// way we can err our before first read/write.
if err := conn.Handshake(); err != nil {
if err := tlsConn.HandshakeContext(ctx); err != nil {
tlsConn.Close()
return nil, trace.Wrap(err)
}

return conn, nil
return tlsConn, nil
}

// dialWithProxy dials the host specified by url through an http or an socks5 proxy.
func (s *SpdyRoundTripper) dialWithProxy(req *http.Request, proxyURL *url.URL) (net.Conn, error) {
// ensure we use a canonical host with proxyReq
targetHost := netutil.CanonicalAddr(req.URL)

netDialer := &net.Dialer{
Timeout: 30 * time.Second,
}

proxyDialConn, err := apiclient.DialProxyWithDialer(
s.ctx,
proxyURL,
targetHost,
netDialer,
)
return proxyDialConn, trace.Wrap(err)
}

// RoundTrip executes the Request and upgrades it. After a successful upgrade,
Expand Down Expand Up @@ -186,7 +236,6 @@ func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error)
conn,
),
)

resp, err := http.ReadResponse(responseReader, nil)
if err != nil {
if conn != nil {
Expand Down