Skip to content

Commit

Permalink
Add a .tsh/config file and add support for configuring custom http he…
Browse files Browse the repository at this point in the history
…aders
  • Loading branch information
Alex McGrath committed Feb 25, 2022
1 parent 123fc2c commit 5210182
Show file tree
Hide file tree
Showing 13 changed files with 204 additions and 44 deletions.
6 changes: 4 additions & 2 deletions api/client/contextdialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ func NewDirectDialer(keepAlivePeriod, dialTimeout time.Duration) ContextDialer {
func NewProxyDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
dialer := newTunnelDialer(ssh, keepAlivePeriod, dialTimeout)
return ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -91,7 +92,8 @@ func newTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Dur
// through the SSH reverse tunnel on the proxy.
func newTLSRoutingTunnelDialer(ssh ssh.ClientConfig, keepAlivePeriod, dialTimeout time.Duration, discoveryAddr string, insecure bool) ContextDialer {
return ContextDialerFunc(func(ctx context.Context, network, addr string) (conn net.Conn, err error) {
tunnelAddr, err := webclient.GetTunnelAddr(ctx, discoveryAddr, insecure, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: discoveryAddr, Insecure: insecure})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
61 changes: 37 additions & 24 deletions api/client/webclient/webclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,22 @@ import (
log "github.com/sirupsen/logrus"
)

type Config struct {
Context context.Context
ProxyAddr string
Insecure bool
Pool *x509.CertPool
ConnectorName string
ExtraHeaders map[string]string
}

// newWebClient creates a new client to the HTTPS web proxy.
func newWebClient(insecure bool, pool *x509.CertPool) *http.Client {
func newWebClient(cfg *Config) *http.Client {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: pool,
InsecureSkipVerify: insecure,
RootCAs: cfg.Pool,
InsecureSkipVerify: cfg.Insecure,
},
},
}
Expand All @@ -56,9 +65,13 @@ func newWebClient(insecure bool, pool *x509.CertPool) *http.Client {
// * The target host must resolve to the loopback address.
// If these conditions are not met, then the plain-HTTP fallback is not allowed,
// and a the HTTPS failure will be considered final.
func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (*http.Response, error) {
func doWithFallback(clt *http.Client, allowPlainHTTP bool, extraHeaders map[string]string, req *http.Request) (*http.Response, error) {
// first try https and see how that goes
req.URL.Scheme = "https"
for k, v := range extraHeaders {
req.Header.Add(k, v)
}

log.Debugf("Attempting %s %s%s", req.Method, req.URL.Host, req.URL.Path)
resp, err := clt.Do(req)

Expand Down Expand Up @@ -88,18 +101,18 @@ func doWithFallback(clt *http.Client, allowPlainHTTP bool, req *http.Request) (*

// Find fetches discovery data by connecting to the given web proxy address.
// It is designed to fetch proxy public addresses without any inefficiencies.
func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*PingResponse, error) {
clt := newWebClient(insecure, pool)
func Find(cfg *Config) (*PingResponse, error) {
clt := newWebClient(cfg)
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/find", proxyAddr)
endpoint := fmt.Sprintf("https://%s/webapi/find", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := doWithFallback(clt, insecure, req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -118,21 +131,21 @@ func Find(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP
// errors before being asked for passwords. The second is to return the form
// of authentication that the server supports. This also leads to better user
// experience: users only get prompted for the type of authentication the server supports.
func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool, connectorName string) (*PingResponse, error) {
clt := newWebClient(insecure, pool)
func Ping(cfg *Config) (*PingResponse, error) {
clt := newWebClient(cfg)
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/ping", proxyAddr)
if connectorName != "" {
endpoint = fmt.Sprintf("%s/%s", endpoint, connectorName)
endpoint := fmt.Sprintf("https://%s/webapi/ping", cfg.ProxyAddr)
if cfg.ConnectorName != "" {
endpoint = fmt.Sprintf("%s/%s", endpoint, cfg.ConnectorName)
}

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := doWithFallback(clt, insecure, req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -147,32 +160,32 @@ func Ping(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertP
}

// GetTunnelAddr returns the tunnel address either set in an environment variable or retrieved from the web proxy.
func GetTunnelAddr(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (string, error) {
func GetTunnelAddr(cfg *Config) (string, error) {
// If TELEPORT_TUNNEL_PUBLIC_ADDR is set, nothing else has to be done, return it.
if tunnelAddr := os.Getenv(defaults.TunnelPublicAddrEnvar); tunnelAddr != "" {
return extractHostPort(tunnelAddr)
}

// Ping web proxy to retrieve tunnel proxy address.
pr, err := Find(ctx, proxyAddr, insecure, nil)
pr, err := Find(cfg)
if err != nil {
return "", trace.Wrap(err)
}
return tunnelAddr(proxyAddr, pr.Proxy)
return tunnelAddr(cfg.ProxyAddr, pr.Proxy)
}

func GetMOTD(ctx context.Context, proxyAddr string, insecure bool, pool *x509.CertPool) (*MotD, error) {
clt := newWebClient(insecure, pool)
func GetMOTD(cfg *Config) (*MotD, error) {
clt := newWebClient(cfg)
defer clt.CloseIdleConnections()

endpoint := fmt.Sprintf("https://%s/webapi/motd", proxyAddr)
endpoint := fmt.Sprintf("https://%s/webapi/motd", cfg.ProxyAddr)

req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
req, err := http.NewRequestWithContext(cfg.Context, http.MethodGet, endpoint, nil)
if err != nil {
return nil, trace.Wrap(err)
}

resp, err := clt.Do(req)
resp, err := doWithFallback(clt, cfg.Insecure, cfg.ExtraHeaders, req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 4 additions & 3 deletions api/client/webclient/webclient_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,15 @@ func TestPlainHttpFallback(t *testing.T) {
desc: "Ping",
handler: newPingHandler("/webapi/ping"),
actionUnderTest: func(addr string, insecure bool) error {
_, err := Ping(context.Background(), addr, insecure, nil /*pool*/, "")
_, err := Ping(
&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
return err
},
}, {
desc: "Find",
handler: newPingHandler("/webapi/find"),
actionUnderTest: func(addr string, insecure bool) error {
_, err := Find(context.Background(), addr, insecure, nil /*pool*/)
_, err := Find(&Config{Context: context.Background(), ProxyAddr: addr, Insecure: insecure})
return err
},
},
Expand Down Expand Up @@ -104,7 +105,7 @@ func TestPlainHttpFallback(t *testing.T) {

func TestGetTunnelAddr(t *testing.T) {
t.Setenv(defaults.TunnelPublicAddrEnvar, "tunnel.example.com:4024")
tunnelAddr, err := GetTunnelAddr(context.Background(), "", true, nil)
tunnelAddr, err := GetTunnelAddr(&Config{Context: context.Background(), ProxyAddr: "", Insecure: false})
require.NoError(t, err)
require.Equal(t, "tunnel.example.com:4024", tunnelAddr)
}
Expand Down
27 changes: 17 additions & 10 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,9 @@ type Config struct {

// Invited is a list of people invited to a session.
Invited []string

// ExtraProxyHeaders is a collection of http headers to be included in requests to the WebProxy.
ExtraProxyHeaders map[string]string
}

// CachePolicy defines cache policy for local clients
Expand Down Expand Up @@ -2544,12 +2547,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er
if tc.lastPing != nil {
return tc.lastPing, nil
}
pr, err := webclient.Ping(
ctx,
tc.WebProxyAddr,
tc.InsecureSkipVerify,
loopbackPool(tc.WebProxyAddr),
tc.AuthConnector)
pr, err := webclient.Ping(&webclient.Config{
Context: ctx,
ProxyAddr: tc.WebProxyAddr,
Insecure: tc.InsecureSkipVerify,
Pool: loopbackPool(tc.WebProxyAddr),
ConnectorName: tc.AuthConnector,
ExtraHeaders: tc.ExtraProxyHeaders})
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -2581,10 +2585,13 @@ func (tc *TeleportClient) Ping(ctx context.Context) (*webclient.PingResponse, er
// confirmation from the user.
func (tc *TeleportClient) ShowMOTD(ctx context.Context) error {
motd, err := webclient.GetMOTD(
ctx,
tc.WebProxyAddr,
tc.InsecureSkipVerify,
loopbackPool(tc.WebProxyAddr))
&webclient.Config{
Context: ctx,
ProxyAddr: tc.WebProxyAddr,
Insecure: tc.InsecureSkipVerify,
Pool: loopbackPool(tc.WebProxyAddr),
ExtraHeaders: tc.ExtraProxyHeaders})

if err != nil {
return trace.Wrap(err)
}
Expand Down
24 changes: 23 additions & 1 deletion lib/client/keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ const (
// keyFilePerms is the default permissions applied to key files (.cert, .key, pub)
// under ~/.tsh
keyFilePerms os.FileMode = 0600

// tshConfigFileName is the name of the directory containing the
// tsh config file.
tshConfigFileName = "config"
)

// LocalKeyStore interface allows for different storage backends for tsh to
Expand Down Expand Up @@ -223,9 +227,27 @@ func (fs *FSLocalKeyStore) DeleteUserCerts(idx KeyIndex, opts ...CertOption) err

// DeleteKeys removes all session keys.
func (fs *FSLocalKeyStore) DeleteKeys() error {
if err := os.RemoveAll(fs.KeyDir); err != nil {

files, err := os.ReadDir(fs.KeyDir)
if err != nil {
return trace.ConvertSystemError(err)
}
for _, file := range files {
if file.IsDir() && file.Name() == tshConfigFileName {
continue
}
if file.IsDir() {
err := os.RemoveAll(filepath.Join(fs.KeyDir, file.Name()))
if err != nil {
return trace.ConvertSystemError(err)
}
continue
}
err := os.Remove(filepath.Join(fs.KeyDir, file.Name()))
if err != nil {
return trace.ConvertSystemError(err)
}
}
return nil
}

Expand Down
13 changes: 13 additions & 0 deletions lib/client/keystore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,19 @@ func TestAddKey_withoutSSHCert(t *testing.T) {
require.Len(t, keyCopy.DBTLSCerts, 1)
}

func TestConfigDirNotDeleted(t *testing.T) {
s, cleanup := newTest(t)
t.Cleanup(cleanup)
idx := KeyIndex{"host.a", "bob", "root"}
s.store.AddKey(s.makeSignedKey(t, idx, false))
configPath := filepath.Join(s.storeDir, "config")
require.NoError(t, os.Mkdir(configPath, 0700))
require.NoError(t, s.store.DeleteKeys())
require.DirExists(t, configPath)

require.NoDirExists(t, filepath.Join(s.storeDir, "keys"))
}

type keyStoreTest struct {
storeDir string
store *FSLocalKeyStore
Expand Down
4 changes: 3 additions & 1 deletion lib/reversetunnel/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ func (a *Agent) getHostCheckers() ([]ssh.PublicKey, error) {
// If this is Web Service port check if proxy support ALPN SNI Listener.
func (a *Agent) getReverseTunnelDetails() *reverseTunnelDetails {
pd := reverseTunnelDetails{TLSRoutingEnabled: false}
resp, err := webclient.Find(a.ctx, a.Addr.Addr, lib.IsInsecureDevMode(), nil)
resp, err := webclient.Find(
&webclient.Config{Context: a.ctx, ProxyAddr: a.Addr.Addr, Insecure: lib.IsInsecureDevMode()})

if err != nil {
// If TLS Routing is disabled the address is the proxy reverse tunnel
// address the ping call will always fail.
Expand Down
4 changes: 3 additions & 1 deletion lib/reversetunnel/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ func WebClientResolver(ctx context.Context, addrs []utils.NetAddr, insecureTLS b
for _, addr := range addrs {
// In insecure mode, any certificate is accepted. In secure mode the hosts
// CAs are used to validate the certificate on the proxy.
tunnelAddr, err := webclient.GetTunnelAddr(ctx, addr.String(), insecureTLS, nil)
tunnelAddr, err := webclient.GetTunnelAddr(
&webclient.Config{Context: ctx, ProxyAddr: addr.String(), Insecure: insecureTLS})

if err != nil {
errs = append(errs, err)
continue
Expand Down
3 changes: 2 additions & 1 deletion lib/reversetunnel/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func (t *TunnelAuthDialer) DialContext(ctx context.Context, _, _ string) (net.Co
}

// Check if t.ProxyAddr is ProxyWebPort and remote Proxy supports TLS ALPNSNIListener.
resp, err := webclient.Find(ctx, addr.Addr, t.InsecureSkipTLSVerify, nil)
resp, err := webclient.Find(
&webclient.Config{Context: ctx, ProxyAddr: addr.Addr, Insecure: t.InsecureSkipTLSVerify})
if err != nil {
// If TLS Routing is disabled the address is the proxy reverse tunnel
// address thus the ping call will always fail.
Expand Down
1 change: 0 additions & 1 deletion lib/web/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,6 @@ func defaultAuthenticationSettings(ctx context.Context, authClient auth.ClientI)

func (h *Handler) ping(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) {
var err error

defaultSettings, err := defaultAuthenticationSettings(r.Context(), h.cfg.ProxyClient)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
30 changes: 30 additions & 0 deletions tool/tsh/tsh.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"os/signal"
"path"
"path/filepath"
"regexp"
"runtime"
"sort"
"strings"
Expand All @@ -39,6 +40,7 @@ import (
"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/types"
apiutils "github.com/gravitational/teleport/api/utils"
apisshutils "github.com/gravitational/teleport/api/utils/sshutils"
Expand Down Expand Up @@ -283,6 +285,9 @@ type CLIConf struct {

// JoinMode is the participant mode someone is joining a session as.
JoinMode string

// ExtraProxyHeaders is configuration read from the .tsh/config/config.yaml file.
ExtraProxyHeaders []ExtraProxyHeaders
}

// Stdout returns the stdout writer.
Expand Down Expand Up @@ -645,6 +650,15 @@ func Run(args []string, opts ...cliOption) error {

setEnvFlags(&cf, os.Getenv)

confOptions, err := loadConfig(cf.HomePath)
if err != nil && !trace.IsNotFound(err) {
return trace.Wrap(err, "failed to load tsh config from %s",
filepath.Join(profile.FullProfilePath(cf.HomePath), tshConfigPath))
}
if confOptions != nil {
cf.ExtraProxyHeaders = confOptions.ExtraHeaders
}

switch command {
case ver.FullCommand():
utils.PrintVersion()
Expand Down Expand Up @@ -1964,6 +1978,22 @@ func makeClient(cf *CLIConf, useProfileLogin bool) (*client.TeleportClient, erro
return nil, trace.Wrap(err)
}

if c.ExtraProxyHeaders == nil {
c.ExtraProxyHeaders = map[string]string{}
}
for _, proxyHeaders := range cf.ExtraProxyHeaders {
proxyGlob := utils.GlobToRegexp(proxyHeaders.Proxy)
proxyRegexp, err := regexp.Compile(proxyGlob)
if err != nil {
return nil, trace.WrapWithMessage(err, "invalid proxy glob %q in tsh configuration file", proxyGlob)
}
if proxyRegexp.MatchString(c.WebProxyAddr) {
for k, v := range proxyHeaders.Headers {
c.ExtraProxyHeaders[k] = v
}
}
}

if len(fPorts) > 0 {
c.LocalForwardPorts = fPorts
}
Expand Down
Loading

0 comments on commit 5210182

Please sign in to comment.