diff --git a/lib/service/service.go b/lib/service/service.go index 64a5113dfe779..1c249250d1e21 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -5190,15 +5190,26 @@ func (process *TeleportProcess) singleProcessMode(mode types.ProxyListenerMode) } if !process.Config.Proxy.DisableTLS && !process.Config.Proxy.DisableALPNSNIListener && mode == types.ProxyListenerMode_Multiplex { - if len(process.Config.Proxy.PublicAddrs) != 0 { - return &process.Config.Proxy.PublicAddrs[0], true - } + var addr utils.NetAddr + switch { + // Use the public address if available. + case len(process.Config.Proxy.PublicAddrs) != 0: + addr = process.Config.Proxy.PublicAddrs[0] + // If WebAddress is unspecified "0.0.0.0" replace 0.0.0.0 with localhost since 0.0.0.0 is never a valid // principal (auth server explicitly removes it when issuing host certs) and when WebPort is used // in the single process mode to establish SSH reverse tunnel connection the host is validated against // the valid principal list. - addr := process.Config.Proxy.WebAddr - addr.Addr = utils.ReplaceUnspecifiedHost(&addr, defaults.HTTPListenPort) + default: + addr = process.Config.Proxy.WebAddr + addr.Addr = utils.ReplaceUnspecifiedHost(&addr, defaults.HTTPListenPort) + } + + // In case the address has "https" scheme for TLS Routing, make sure + // "tcp" is used when dialing reverse tunnel. + if addr.AddrNetwork == "https" { + addr.AddrNetwork = "tcp" + } return &addr, true } diff --git a/lib/service/service_test.go b/lib/service/service_test.go index 7539675e22cf5..0435e6e65f7e0 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -1156,3 +1156,121 @@ func TestEnterpriseServicesEnabled(t *testing.T) { }) } } + +func TestSingleProcessModeResolver(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + mode types.ProxyListenerMode + config servicecfg.Config + wantError bool + wantAddr string + }{ + { + name: "not single process mode", + mode: types.ProxyListenerMode_Separate, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + }, + Auth: servicecfg.AuthConfig{ + Enabled: false, + }, + }, + wantError: true, + }, + { + name: "reverse tunnel disabled", + mode: types.ProxyListenerMode_Separate, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + DisableReverseTunnel: true, + }, + Auth: servicecfg.AuthConfig{ + Enabled: true, + }, + }, + wantError: true, + }, + { + name: "separate port localhost", + mode: types.ProxyListenerMode_Separate, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + }, + Auth: servicecfg.AuthConfig{ + Enabled: true, + }, + }, + wantAddr: "tcp://localhost:3024", + }, + { + name: "separate port tunnel addr", + mode: types.ProxyListenerMode_Separate, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + TunnelPublicAddrs: []utils.NetAddr{ + *utils.MustParseAddr("example.com:12345"), + *utils.MustParseAddr("example.org:12345"), + }, + }, + Auth: servicecfg.AuthConfig{ + Enabled: true, + }, + }, + wantAddr: "tcp://example.com:12345", + }, + { + name: "multiplex public addr", + mode: types.ProxyListenerMode_Multiplex, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + PublicAddrs: []utils.NetAddr{ + *utils.MustParseAddr("example.com:12345"), + *utils.MustParseAddr("example.org:12345"), + }, + }, + Auth: servicecfg.AuthConfig{ + Enabled: true, + }, + }, + wantAddr: "tcp://example.com:12345", + }, + { + name: "multiplex web addr with https scheme", + mode: types.ProxyListenerMode_Multiplex, + config: servicecfg.Config{ + Proxy: servicecfg.ProxyConfig{ + Enabled: true, + WebAddr: *utils.MustParseAddr("https://example.com:12345"), + }, + Auth: servicecfg.AuthConfig{ + Enabled: true, + }, + }, + wantAddr: "tcp://example.com:12345", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + process := TeleportProcess{Config: &test.config} + resolver := process.SingleProcessModeResolver(test.mode) + require.NotNil(t, resolver) + addr, mode, err := resolver(context.Background()) + if test.wantError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, mode, test.mode) + require.Equal(t, addr.FullAddress(), test.wantAddr) + }) + } +}