diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 42e7c72432b46..f8a2a3eadbeab 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -30,6 +30,7 @@ import ( mathrand "math/rand" "net" "net/http" + "net/url" "path/filepath" "regexp" "strconv" @@ -55,6 +56,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/util/httpstream" + utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apiserver/pkg/util/wsstream" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" @@ -2164,6 +2166,7 @@ func (f *Forwarder) getExecutor(ctx authContext, sess *clusterSession, req *http tlsConfig: sess.tlsConfig, pingPeriod: f.cfg.ConnPingPeriod, originalHeaders: req.Header, + proxier: sess.getProxier(), }) rt := http.RoundTripper(upgradeRoundTripper) if sess.kubeAPICreds != nil { @@ -2186,6 +2189,7 @@ func (f *Forwarder) getDialer(ctx authContext, sess *clusterSession, req *http.R tlsConfig: sess.tlsConfig, pingPeriod: f.cfg.ConnPingPeriod, originalHeaders: req.Header, + proxier: sess.getProxier(), }) rt := http.RoundTripper(upgradeRoundTripper) if sess.kubeAPICreds != nil { @@ -2319,6 +2323,20 @@ func (s *clusterSession) dial(ctx context.Context, network string) (net.Conn, er return nil, trace.NewAggregate(errs...) } +// getProxier returns the proxier function to use for this session. +// If the target cluster is not served by this teleport service, the proxier +// must be nil to avoid using it through the reverse tunnel. +// If the target cluster is served by this teleport service, the proxier +// must be set to the default proxy function. +func (s *clusterSession) getProxier() func(req *http.Request) (*url.URL, error) { + // When the target cluster is not served by this teleport service, the + // proxier must be nil to avoid using it through the reverse tunnel. + if s.kubeAPICreds == nil { + return nil + } + return utilnet.NewProxierWithNoProxyCIDR(http.ProxyFromEnvironment) +} + // TODO(awly): unit test this func (f *Forwarder) newClusterSession(ctx context.Context, authCtx authContext) (*clusterSession, error) { ctx, span := f.cfg.tracer.Start( diff --git a/lib/kube/proxy/roundtrip.go b/lib/kube/proxy/roundtrip.go index 48f9442d088f9..063e64137ab99 100644 --- a/lib/kube/proxy/roundtrip.go +++ b/lib/kube/proxy/roundtrip.go @@ -40,7 +40,7 @@ import ( utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apimachinery/third_party/forked/golang/netutil" - "github.com/gravitational/teleport/lib/utils" + apiclient "github.com/gravitational/teleport/api/client" ) // SpdyRoundTripper knows how to upgrade an HTTP request to one that supports @@ -71,6 +71,8 @@ type SpdyRoundTripper struct { pingPeriod time.Duration // originalHeaders are the headers that were passed from the original request. originalHeaders http.Header + + proxier func(*http.Request) (*url.URL, error) } var ( @@ -89,12 +91,13 @@ type roundTripperConfig struct { tlsConfig *tls.Config pingPeriod time.Duration originalHeaders http.Header + proxier func(*http.Request) (*url.URL, error) } // NewSpdyRoundTripperWithDialer creates a new SpdyRoundTripper that will use // the specified tlsConfig. This function is mostly meant for unit tests. func NewSpdyRoundTripperWithDialer(cfg roundTripperConfig) *SpdyRoundTripper { - return &SpdyRoundTripper{tlsConfig: cfg.tlsConfig, dialWithContext: cfg.dial, ctx: cfg.ctx, authCtx: cfg.authCtx, pingPeriod: cfg.pingPeriod, originalHeaders: cfg.originalHeaders} + return &SpdyRoundTripper{tlsConfig: cfg.tlsConfig, dialWithContext: cfg.dial, ctx: cfg.ctx, authCtx: cfg.authCtx, pingPeriod: cfg.pingPeriod, originalHeaders: cfg.originalHeaders, proxier: cfg.proxier} } // TLSClientConfig implements pkg/util/net.TLSClientConfigHolder for proper TLS checking during @@ -105,7 +108,7 @@ func (s *SpdyRoundTripper) TLSClientConfig() *tls.Config { // Dial implements k8s.io/apimachinery/pkg/util/net.Dialer. func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { - conn, err := s.dial(req.URL) + conn, err := s.dial(req) if err != nil { return nil, err } @@ -119,37 +122,84 @@ func (s *SpdyRoundTripper) Dial(req *http.Request) (net.Conn, error) { } // dial dials the host specified by url, using TLS if appropriate. -func (s *SpdyRoundTripper) dial(url *url.URL) (net.Conn, error) { - dialAddr := netutil.CanonicalAddr(url) - - if url.Scheme == "http" { - switch { - case s.dialWithContext != nil: - return s.dialWithContext(s.ctx, "tcp", dialAddr) - default: - return net.Dial("tcp", dialAddr) +func (s *SpdyRoundTripper) dial(req *http.Request) (conn net.Conn, err error) { + var proxyURL *url.URL + if s.proxier != nil { + proxyURL, err = s.proxier(req) + if err != nil { + return nil, err } } - // TODO validate the TLSClientConfig is set up? - var conn *tls.Conn - var err error - if s.dialWithContext == nil { - conn, err = tls.Dial("tcp", dialAddr, s.tlsConfig) + if proxyURL == nil { + conn, err = s.dialWithoutProxy(req.URL) } else { - conn, err = utils.TLSDial(s.ctx, s.dialWithContext, "tcp", dialAddr, s.tlsConfig) + conn, err = s.dialWithProxy(req, proxyURL) } if err != nil { return nil, trace.Wrap(err) } + if req.URL.Scheme == "https" { + return s.tlsConn(s.ctx, conn, netutil.CanonicalAddr(req.URL)) + } + return conn, nil +} + +func (s *SpdyRoundTripper) dialWithoutProxy(url *url.URL) (conn net.Conn, err error) { + dialAddr := netutil.CanonicalAddr(url) + switch { + case s.dialWithContext != nil: + conn, err = s.dialWithContext(s.ctx, "tcp", dialAddr) + default: + conn, err = net.Dial("tcp", dialAddr) + } + return conn, trace.Wrap(err) +} + +// tlsConn returns a TLS client side connection using rwc as the underlying transport. +func (s *SpdyRoundTripper) tlsConn(ctx context.Context, rwc net.Conn, targetHost string) (net.Conn, error) { + host, _, err := net.SplitHostPort(targetHost) + if err != nil { + return nil, err + } + + tlsConfig := s.tlsConfig + switch { + case tlsConfig == nil: + tlsConfig = &tls.Config{ServerName: host} + case len(tlsConfig.ServerName) == 0: + tlsConfig = tlsConfig.Clone() + tlsConfig.ServerName = host + } + tlsConn := tls.Client(rwc, tlsConfig) + // Client handshake will verify the server hostname and cert chain. That // way we can err our before first read/write. - if err := conn.Handshake(); err != nil { + if err := tlsConn.HandshakeContext(ctx); err != nil { + tlsConn.Close() return nil, trace.Wrap(err) } - return conn, nil + return tlsConn, nil +} + +// dialWithProxy dials the host specified by url through an http or an socks5 proxy. +func (s *SpdyRoundTripper) dialWithProxy(req *http.Request, proxyURL *url.URL) (net.Conn, error) { + // ensure we use a canonical host with proxyReq + targetHost := netutil.CanonicalAddr(req.URL) + + netDialer := &net.Dialer{ + Timeout: 30 * time.Second, + } + + proxyDialConn, err := apiclient.DialProxyWithDialer( + s.ctx, + proxyURL, + targetHost, + netDialer, + ) + return proxyDialConn, trace.Wrap(err) } // RoundTrip executes the Request and upgrades it. After a successful upgrade, @@ -186,7 +236,6 @@ func (s *SpdyRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) conn, ), ) - resp, err := http.ReadResponse(responseReader, nil) if err != nil { if conn != nil {