diff --git a/pkg/agent/config/config.go b/pkg/agent/config/config.go index 4883297713ad..26043c8bd8e4 100644 --- a/pkg/agent/config/config.go +++ b/pkg/agent/config/config.go @@ -91,15 +91,24 @@ func KubeProxyDisabled(ctx context.Context, node *config.Node, proxy proxy.Proxy return disabled } -// APIServers returns a list of apiserver endpoints, suitable for seeding client loadbalancer configurations. +// WaitForAPIServers returns a list of apiserver endpoints, suitable for seeding client loadbalancer configurations. // This function will block until it can return a populated list of apiservers, or if the remote server returns // an error (indicating that it does not support this functionality). -func APIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) []string { +func WaitForAPIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) []string { var addresses []string + var info *clientaccess.Info var err error _ = wait.PollUntilContextCancel(ctx, 5*time.Second, true, func(ctx context.Context) (bool, error) { - addresses, err = getAPIServers(ctx, node, proxy) + if info == nil { + withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) + info, err = clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) + if err != nil { + logrus.Warnf("Failed to validate server token: %v", err) + return false, nil + } + } + addresses, err = GetAPIServers(ctx, info) if err != nil { logrus.Infof("Failed to retrieve list of apiservers from server: %v", err) return false, err @@ -760,14 +769,8 @@ func get(ctx context.Context, envInfo *cmds.Agent, proxy proxy.Proxy) (*config.N return nodeConfig, nil } -// getAPIServers attempts to return a list of apiservers from the server. -func getAPIServers(ctx context.Context, node *config.Node, proxy proxy.Proxy) ([]string, error) { - withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) - info, err := clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) - if err != nil { - return nil, err - } - +// GetAPIServers attempts to return a list of apiservers from the server. +func GetAPIServers(ctx context.Context, info *clientaccess.Info) ([]string, error) { data, err := info.Get("/v1-" + version.Program + "/apiservers") if err != nil { return nil, err diff --git a/pkg/agent/loadbalancer/config.go b/pkg/agent/loadbalancer/config.go index 1620c8ab6bbc..b7d8f63f9d10 100644 --- a/pkg/agent/loadbalancer/config.go +++ b/pkg/agent/loadbalancer/config.go @@ -7,8 +7,18 @@ import ( "github.com/k3s-io/k3s/pkg/agent/util" ) +// lbConfig stores loadbalancer state that should be persisted across restarts. +type lbConfig struct { + ServerURL string `json:"ServerURL"` + ServerAddresses []string `json:"ServerAddresses"` +} + func (lb *LoadBalancer) writeConfig() error { - configOut, err := json.MarshalIndent(lb, "", " ") + config := &lbConfig{ + ServerURL: lb.scheme + "://" + lb.servers.getDefaultAddress(), + ServerAddresses: lb.servers.getAddresses(), + } + configOut, err := json.MarshalIndent(config, "", " ") if err != nil { return err } @@ -16,20 +26,17 @@ func (lb *LoadBalancer) writeConfig() error { } func (lb *LoadBalancer) updateConfig() error { - writeConfig := true if configBytes, err := os.ReadFile(lb.configFile); err == nil { - config := &LoadBalancer{} + config := &lbConfig{} if err := json.Unmarshal(configBytes, config); err == nil { - if config.ServerURL == lb.ServerURL { - writeConfig = false - lb.setServers(config.ServerAddresses) + // if the default server from the config matches our current default, + // load the rest of the addresses as well. + if config.ServerURL == lb.scheme+"://"+lb.servers.getDefaultAddress() { + lb.Update(config.ServerAddresses) + return nil } } } - if writeConfig { - if err := lb.writeConfig(); err != nil { - return err - } - } - return nil + // config didn't exist or used a different default server, write the current config to disk. + return lb.writeConfig() } diff --git a/pkg/agent/loadbalancer/httpproxy.go b/pkg/agent/loadbalancer/httpproxy.go new file mode 100644 index 000000000000..ea9711824975 --- /dev/null +++ b/pkg/agent/loadbalancer/httpproxy.go @@ -0,0 +1,70 @@ +package loadbalancer + +import ( + "fmt" + "net" + "net/url" + "os" + "strconv" + "time" + + "github.com/k3s-io/k3s/pkg/version" + http_dialer "github.com/mwitkow/go-http-dialer" + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "golang.org/x/net/http/httpproxy" + "golang.org/x/net/proxy" +) + +var defaultDialer proxy.Dialer = &net.Dialer{ + Timeout: 10 * time.Second, + KeepAlive: 30 * time.Second, +} + +// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, +// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured +// to indicate use of a HTTP proxy for the server URL. +func SetHTTPProxy(address string) error { + // Check if env variable for proxy is set + if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy || address == "" { + return nil + } + + serverURL, err := url.Parse(address) + if err != nil { + return errors.Wrapf(err, "failed to parse address %s", address) + } + + // Call this directly instead of using the cached environment used by http.ProxyFromEnvironment to allow for testing + proxyFromEnvironment := httpproxy.FromEnvironment().ProxyFunc() + proxyURL, err := proxyFromEnvironment(serverURL) + if err != nil { + return errors.Wrapf(err, "failed to get proxy for address %s", address) + } + if proxyURL == nil { + logrus.Debug(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED is true but no proxy is configured for URL " + serverURL.String()) + return nil + } + + dialer, err := proxyDialer(proxyURL, defaultDialer) + if err != nil { + return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) + } + + defaultDialer = dialer + logrus.Debugf("Using proxy %s for agent connection to %s", proxyURL, serverURL) + return nil +} + +// proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. +func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { + if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + // Create a new HTTP proxy dialer + httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithConnectionTimeout(10*time.Second), http_dialer.WithDialer(forward.(*net.Dialer))) + return httpProxyDialer, nil + } else if proxyURL.Scheme == "socks5" { + // For SOCKS5 proxies, use the proxy package's FromURL + return proxy.FromURL(proxyURL, forward) + } + return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) +} diff --git a/pkg/agent/loadbalancer/servers_test.go b/pkg/agent/loadbalancer/httpproxy_test.go similarity index 97% rename from pkg/agent/loadbalancer/servers_test.go rename to pkg/agent/loadbalancer/httpproxy_test.go index c8b8b5b924bb..07f72e927e77 100644 --- a/pkg/agent/loadbalancer/servers_test.go +++ b/pkg/agent/loadbalancer/httpproxy_test.go @@ -2,15 +2,16 @@ package loadbalancer import ( "fmt" - "net" "os" "strings" "testing" "github.com/k3s-io/k3s/pkg/version" "github.com/sirupsen/logrus" + "golang.org/x/net/proxy" ) +var originalDialer proxy.Dialer var defaultEnv map[string]string var proxyEnvs = []string{version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED", "HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY", "http_proxy", "https_proxy", "no_proxy"} @@ -19,7 +20,7 @@ func init() { } func prepareEnv(env ...string) { - defaultDialer = &net.Dialer{} + originalDialer = defaultDialer defaultEnv = map[string]string{} for _, e := range proxyEnvs { if v, ok := os.LookupEnv(e); ok { @@ -34,6 +35,7 @@ func prepareEnv(env ...string) { } func restoreEnv() { + defaultDialer = originalDialer for _, e := range proxyEnvs { if v, ok := defaultEnv[e]; ok { os.Setenv(e, v) diff --git a/pkg/agent/loadbalancer/loadbalancer.go b/pkg/agent/loadbalancer/loadbalancer.go index 6689a9e7ca39..09727db18922 100644 --- a/pkg/agent/loadbalancer/loadbalancer.go +++ b/pkg/agent/loadbalancer/loadbalancer.go @@ -2,12 +2,12 @@ package loadbalancer import ( "context" - "errors" "fmt" "net" + "net/url" "os" "path/filepath" - "sync" + "strings" "time" "github.com/inetaf/tcpproxy" @@ -15,43 +15,17 @@ import ( "github.com/sirupsen/logrus" ) -// server tracks the connections to a server, so that they can be closed when the server is removed. -type server struct { - // This mutex protects access to the connections map. All direct access to the map should be protected by it. - mutex sync.Mutex - address string - healthCheck func() bool - connections map[net.Conn]struct{} -} - -// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. -type serverConn struct { - server *server - net.Conn -} - // LoadBalancer holds data for a local listener which forwards connections to a // pool of remote servers. It is not a proper load-balancer in that it does not // actually balance connections, but instead fails over to a new server only // when a connection attempt to the currently selected server fails. type LoadBalancer struct { - // This mutex protects access to servers map and randomServers list. - // All direct access to the servers map/list should be protected by it. - mutex sync.RWMutex - proxy *tcpproxy.Proxy - - serviceName string - configFile string - localAddress string - localServerURL string - defaultServerAddress string - ServerURL string - ServerAddresses []string - randomServers []string - servers map[string]*server - currentServerAddress string - nextServerIndex int - Listener net.Listener + serviceName string + configFile string + scheme string + localAddress string + servers serverList + proxy *tcpproxy.Proxy } const RandomPort = 0 @@ -64,7 +38,7 @@ var ( // New contstructs a new LoadBalancer instance. The default server URL, and // currently active servers, are stored in a file within the dataDir. -func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { +func New(ctx context.Context, dataDir, serviceName, defaultServerURL string, lbServerPort int, isIPv6 bool) (_lb *LoadBalancer, _err error) { config := net.ListenConfig{Control: reusePort} var localAddress string if isIPv6 { @@ -85,30 +59,35 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo return nil, err } - // if lbServerPort was 0, the port was assigned by the OS when bound - see what we ended up with. - localAddress = listener.Addr().String() - - defaultServerAddress, localServerURL, err := parseURL(serverURL, localAddress) + serverURL, err := url.Parse(defaultServerURL) if err != nil { return nil, err } - if serverURL == localServerURL { - logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) - defaultServerAddress = "" + // Set explicit port from scheme + if serverURL.Port() == "" { + if strings.ToLower(serverURL.Scheme) == "http" { + serverURL.Host += ":80" + } + if strings.ToLower(serverURL.Scheme) == "https" { + serverURL.Host += ":443" + } } lb := &LoadBalancer{ - serviceName: serviceName, - configFile: filepath.Join(dataDir, "etc", serviceName+".json"), - localAddress: localAddress, - localServerURL: localServerURL, - defaultServerAddress: defaultServerAddress, - servers: make(map[string]*server), - ServerURL: serverURL, + serviceName: serviceName, + configFile: filepath.Join(dataDir, "etc", serviceName+".json"), + scheme: serverURL.Scheme, + localAddress: listener.Addr().String(), } - lb.setServers([]string{lb.defaultServerAddress}) + // if starting pointing at ourselves, don't set a default server address, + // which will cause all dials to fail until servers are added. + if serverURL.Host == lb.localAddress { + logrus.Debugf("Initial server URL for load balancer %s points at local server URL - starting with empty default server address", serviceName) + } else { + lb.servers.setDefaultAddress(lb.serviceName, serverURL.Host) + } lb.proxy = &tcpproxy.Proxy{ ListenFunc: func(string, string) (net.Listener, error) { @@ -117,8 +96,18 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo } lb.proxy.AddRoute(serviceName, &tcpproxy.DialProxy{ Addr: serviceName, - DialContext: lb.dialContext, OnDialError: onDialError, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + start := time.Now() + status := "success" + conn, err := lb.servers.dialContext(ctx, network, address) + latency := time.Since(start) + if err != nil { + status = "error" + } + loadbalancerDials.WithLabelValues(serviceName, status).Observe(latency.Seconds()) + return conn, err + }, }) if err := lb.updateConfig(); err != nil { @@ -127,85 +116,50 @@ func New(ctx context.Context, dataDir, serviceName, serverURL string, lbServerPo if err := lb.proxy.Start(); err != nil { return nil, err } - logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.ServerAddresses, lb.defaultServerAddress) + logrus.Infof("Running load balancer %s %s -> %v [default: %s]", serviceName, lb.localAddress, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) - go lb.runHealthChecks(ctx) + go lb.servers.runHealthChecks(ctx, lb.serviceName) return lb, nil } +// Update updates the list of server addresses to contain only the listed servers. func (lb *LoadBalancer) Update(serverAddresses []string) { - if lb == nil { + if !lb.servers.setAddresses(lb.serviceName, serverAddresses) { return } - if !lb.setServers(serverAddresses) { - return - } - logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.ServerAddresses, lb.defaultServerAddress) + + logrus.Infof("Updated load balancer %s server addresses -> %v [default: %s]", lb.serviceName, lb.servers.getAddresses(), lb.servers.getDefaultAddress()) if err := lb.writeConfig(); err != nil { logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } } -func (lb *LoadBalancer) LoadBalancerServerURL() string { - if lb == nil { - return "" +// SetDefault sets the selected address as the default / fallback address +func (lb *LoadBalancer) SetDefault(serverAddress string) { + lb.servers.setDefaultAddress(lb.serviceName, serverAddress) + + if err := lb.writeConfig(); err != nil { + logrus.Warnf("Error updating load balancer %s config: %s", lb.serviceName, err) } - return lb.localServerURL } -func (lb *LoadBalancer) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - - var allChecksFailed bool - startIndex := lb.nextServerIndex - for { - targetServer := lb.currentServerAddress - - server := lb.servers[targetServer] - if server == nil || targetServer == "" { - logrus.Debugf("Nil server for load balancer %s: %s", lb.serviceName, targetServer) - } else if allChecksFailed || server.healthCheck() { - dialTime := time.Now() - conn, err := server.dialContext(ctx, network, targetServer) - if err == nil { - return conn, nil - } - logrus.Debugf("Dial error from load balancer %s after %s: %s", lb.serviceName, time.Now().Sub(dialTime), err) - // Don't close connections to the failed server if we're retrying with health checks ignored. - // We don't want to disrupt active connections if it is unlikely they will have anywhere to go. - if !allChecksFailed { - defer server.closeAll() - } - } else { - logrus.Debugf("Dial health check failed for %s", targetServer) - } +// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. +func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck HealthCheckFunc) { + if err := lb.servers.setHealthCheck(address, healthCheck); err != nil { + logrus.Errorf("Failed to set health check for load balancer %s: %v", lb.serviceName, err) + } else { + logrus.Debugf("Set health check for load balancer %s: %s", lb.serviceName, address) + } +} - newServer, err := lb.nextServer(targetServer) - if err != nil { - return nil, err - } - if targetServer != newServer { - logrus.Debugf("Failed over to new server for load balancer %s: %s -> %s", lb.serviceName, targetServer, newServer) - } - if ctx.Err() != nil { - return nil, ctx.Err() - } +func (lb *LoadBalancer) LocalURL() string { + return lb.scheme + "://" + lb.localAddress +} - maxIndex := len(lb.randomServers) - if startIndex > maxIndex { - startIndex = maxIndex - } - if lb.nextServerIndex == startIndex { - if allChecksFailed { - return nil, errors.New("all servers failed") - } - logrus.Debugf("Health checks for all servers in load balancer %s have failed: retrying with health checks ignored", lb.serviceName) - allChecksFailed = true - } - } +func (lb *LoadBalancer) ServerAddresses() []string { + return lb.servers.getAddresses() } func onDialError(src net.Conn, dstDialErr error) { @@ -214,10 +168,9 @@ func onDialError(src net.Conn, dstDialErr error) { } // ResetLoadBalancer will delete the local state file for the load balancer on disk -func ResetLoadBalancer(dataDir, serviceName string) error { +func ResetLoadBalancer(dataDir, serviceName string) { stateFile := filepath.Join(dataDir, "etc", serviceName+".json") - if err := os.Remove(stateFile); err != nil { + if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) { logrus.Warn(err) } - return nil } diff --git a/pkg/agent/loadbalancer/loadbalancer_test.go b/pkg/agent/loadbalancer/loadbalancer_test.go index cbfdf982c690..69b4fca10cab 100644 --- a/pkg/agent/loadbalancer/loadbalancer_test.go +++ b/pkg/agent/loadbalancer/loadbalancer_test.go @@ -5,19 +5,29 @@ import ( "context" "fmt" "net" - "net/url" + "strconv" "strings" "testing" "time" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" "github.com/sirupsen/logrus" ) +func Test_UnitLoadBalancer(t *testing.T) { + _, reporterConfig := GinkgoConfiguration() + reporterConfig.Verbose = testing.Verbose() + RegisterFailHandler(Fail) + RunSpecs(t, "LoadBalancer Suite", reporterConfig) +} + func init() { logrus.SetLevel(logrus.DebugLevel) } type testServer struct { + address string listener net.Listener conns []net.Conn prefix string @@ -31,6 +41,7 @@ func createServer(ctx context.Context, prefix string) (*testServer, error) { s := &testServer{ prefix: prefix, listener: listener, + address: listener.Addr().String(), } go s.serve() go func() { @@ -53,6 +64,7 @@ func (s *testServer) serve() { func (s *testServer) close() { logrus.Printf("testServer %s closing", s.prefix) + s.address = "" s.listener.Close() for _, conn := range s.conns { conn.Close() @@ -69,10 +81,6 @@ func (s *testServer) echo(conn net.Conn) { } } -func (s *testServer) address() string { - return s.listener.Addr().String() -} - func ping(conn net.Conn) (string, error) { fmt.Fprintf(conn, "ping\n") result, err := bufio.NewReader(conn).ReadString('\n') @@ -82,221 +90,340 @@ func ping(conn net.Conn) (string, error) { return strings.TrimSpace(result), nil } -// Test_UnitFailOver creates a LB using a default server (ie fixed registration endpoint) -// and then adds a new server (a node). The node server is then closed, and it is confirmed -// that new connections use the default server. -func Test_UnitFailOver(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - defaultServer, err := createServer(ctx, "default") - if err != nil { - t.Fatalf("createServer(default) failed: %v", err) - } - - node1Server, err := createServer(ctx, "node1") - if err != nil { - t.Fatalf("createServer(node1) failed: %v", err) - } +var _ = Describe("LoadBalancer", func() { + // creates a LB using a default server (ie fixed registration endpoint) + // and then adds a new server (a node). The node server is then closed, and it is confirmed + // that new connections use the default server. + When("loadbalancer is running", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer, node1Server, node2Server *testServer + var conn1, conn2, conn3, conn4 net.Conn + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") + + node1Server, err = createServer(ctx, "node1") + Expect(err).NotTo(HaveOccurred(), "createServer(node1) failed") + + node2Server, err = createServer(ctx, "node2") + Expect(err).NotTo(HaveOccurred(), "createServer(node2) failed") + + // start the loadbalancer with the default server as the only server + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address, RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("adds node1 as a server", func() { + // add the node as a new server address. + lb.Update([]string{node1Server.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + + By(fmt.Sprintf("Added node1 server: %v", lb.servers.getServers())) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) + + It("connects to node1", func() { + // make sure connections go to the node + conn1, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn1)).To(Equal("node1:ping"), "Unexpected ping(conn1) result") - node2Server, err := createServer(ctx, "node2") - if err != nil { - t.Fatalf("createServer(node2) failed: %v", err) - } - - // start the loadbalancer with the default server as the only server - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+defaultServer.address(), RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - parsedURL, err := url.Parse(lb.LoadBalancerServerURL()) - if err != nil { - t.Fatalf("url.Parse failed: %v", err) - } - localAddress := parsedURL.Host - - // add the node as a new server address. - lb.Update([]string{node1Server.address()}) - - // make sure connections go to the node - conn1, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - if result, err := ping(conn1); err != nil { - t.Fatalf("ping(conn1) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn1) result: %v", result) - } - - t.Log("conn1 tested OK") - - // set failing health check for node 1 - lb.SetHealthCheck(node1Server.address(), func() bool { return false }) - - // Server connections are checked every second, now that node 1 is failed - // the connections to it should be closed. - time.Sleep(2 * time.Second) - - if _, err := ping(conn1); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn1") - } - - t.Log("conn1 closed on failure OK") - - // make sure connection still goes to the first node - it is failing health checks but so - // is the default endpoint, so it should be tried first with health checks disabled, - // before failing back to the default. - conn2, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } - - t.Log("conn2 tested OK") - - // make sure the health checks don't close the connection we just made - - // connections should only be closed when it transitions from health to unhealthy. - time.Sleep(2 * time.Second) - - if result, err := ping(conn2); err != nil { - t.Fatalf("ping(conn2) failed: %v", err) - } else if result != "node1:ping" { - t.Fatalf("Unexpected ping(conn2) result: %v", result) - } - - t.Log("conn2 tested OK again") - - // shut down the first node server to force failover to the default - node1Server.close() - - // make sure new connections go to the default, and existing connections are closed - conn3, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn3); err != nil { - t.Fatalf("ping(conn3) failed: %v", err) - } else if result != "default:ping" { - t.Fatalf("Unexpected ping(conn3) result: %v", result) - } + By("conn1 tested OK") + }) + + It("changes node1 state to failed", func() { + // set failing health check for node 1 + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultFailed }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node1Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(stateFailed)) + }) + + It("disconnects from node1", func() { + // Server connections are checked every second, now that node 1 is failed + // the connections to it should be closed. + Expect(ping(conn1)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn1 closed on failure OK") + + // connections shoould go to the default now that node 1 is failed + conn2, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") - t.Log("conn3 tested OK") + By("conn2 tested OK") + }) + + It("does not close connections unexpectedly", func() { + // make sure the health checks don't close the connection we just made - + // connections should only be closed when it transitions from health to unhealthy. + time.Sleep(2 * time.Second) + + Expect(ping(conn2)).To(Equal("default:ping"), "Unexpected ping(conn2) result") + + By("conn2 tested OK again") + }) + + It("closes connections when dial fails", func() { + // shut down the first node server to force failover to the default + node1Server.close() + + // make sure new connections go to the default, and existing connections are closed + conn3, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + Expect(ping(conn3)).To(Equal("default:ping"), "Unexpected ping(conn3) result") + + By("conn3 tested OK") + }) - if _, err := ping(conn2); err == nil { - t.Fatal("Unexpected successful ping on closed connection conn2") - } - - t.Log("conn2 closed on failure OK") - - // add the second node as a new server address. - lb.Update([]string{node2Server.address()}) - - // make sure connection now goes to the second node, - // and connections to the default are closed. - conn4, err := net.Dial("tcp", localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - - } - if result, err := ping(conn4); err != nil { - t.Fatalf("ping(conn4) failed: %v", err) - } else if result != "node2:ping" { - t.Fatalf("Unexpected ping(conn4) result: %v", result) - } - - t.Log("conn4 tested OK") - - // Server connections are checked every second, now that we have a healthy - // server, connections to the default server should be closed - time.Sleep(2 * time.Second) - - if _, err := ping(conn3); err == nil { - t.Fatal("Unexpected successful ping on connection conn3") - } - - t.Log("conn3 closed on failure OK") -} - -// Test_UnitFailFast confirms that connnections to invalid addresses fail quickly -func Test_UnitFailFast(t *testing.T) { - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverURL := "http://127.0.0.1:0/" - lb, err := New(ctx, tmpDir, SupervisorServiceName, serverURL, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(10 * time.Millisecond) - - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from invalid address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} - -// Test_UnitFailUnreachable confirms that connnections to unreachable addresses do fail -// within the expected duration -func Test_UnitFailUnreachable(t *testing.T) { - if testing.Short() { - t.Skip("skipping slow test in short mode.") - } - tmpDir := t.TempDir() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - serverAddr := "192.0.2.1:6443" - lb, err := New(ctx, tmpDir, SupervisorServiceName, "http://"+serverAddr, RandomPort, false) - if err != nil { - t.Fatalf("New() failed: %v", err) - } - - // Set failing health check to reduce retries - lb.SetHealthCheck(serverAddr, func() bool { return false }) - - conn, err := net.Dial("tcp", lb.localAddress) - if err != nil { - t.Fatalf("net.Dial failed: %v", err) - } - - done := make(chan error) - go func() { - _, err = ping(conn) - done <- err - }() - timeout := time.After(11 * time.Second) - - select { - case err := <-done: - if err == nil { - t.Fatal("Unexpected successful ping from unreachable address") - } - case <-timeout: - t.Fatal("Test timed out") - } -} + It("replaces node2 as a server", func() { + // add the second node as a new server address. + lb.Update([]string{node2Server.address}) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + + By(fmt.Sprintf("Added node2 server: %v", lb.servers.getServers())) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(node2Server.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + }) + + It("connects to node2", func() { + // make sure connection now goes to the second node, + // and connections to the default are closed. + conn4, err = net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + Expect(ping(conn4)).To(Equal("node2:ping"), "Unexpected ping(conn3) result") + + By("conn4 tested OK") + }) + + It("does not close connections unexpectedly", func() { + // Server connections are checked every second, now that we have a healthy + // server, connections to the default server should be closed + time.Sleep(2 * time.Second) + + Expect(ping(conn2)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn2 closed on failure OK") + + Expect(ping(conn3)).Error().To(HaveOccurred(), "Unexpected successful ping on closed connection conn1") + + By("conn3 closed on failure OK") + }) + + It("adds default as a server", func() { + // add the default as a full server + lb.Update([]string{node2Server.address, defaultServer.address}) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + + By(fmt.Sprintf("Default server added: %v", lb.servers.getServers())) + }) + + It("returns the default server in the address list", func() { + // confirm that both servers are listed in the address list + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address, defaultServer.address)) + + // confirm that the default is still listed as default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + + }) + + It("does not return the default server in the address list after removing it", func() { + // remove the default as a server + lb.Update([]string{node2Server.address}) + By(fmt.Sprintf("Default removed: %v", lb.servers.getServers())) + + // confirm that it is not listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // but is still listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + + It("removes default server when no longer default", func() { + // set node2 as the default + lb.SetDefault(node2Server.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node2Server.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(node2Server.address), "node2 server is not default") + }) + + It("sets all three servers", func() { + // set node2 as the default + lb.SetDefault(defaultServer.address) + By(fmt.Sprintf("Default set: %v", lb.servers.getServers())) + + lb.Update([]string{node1Server.address, node2Server.address, defaultServer.address}) + lb.SetHealthCheck(node1Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(node2Server.address, func() HealthCheckResult { return HealthCheckResultOK }) + lb.SetHealthCheck(defaultServer.address, func() HealthCheckResult { return HealthCheckResultOK }) + + // wait for state to change + Eventually(func() state { + if s := lb.servers.getServer(defaultServer.address); s != nil { + return s.state + } + return stateInvalid + }, 5, 1).Should(Equal(statePreferred)) + + By(fmt.Sprintf("All servers set: %v", lb.servers.getServers())) + + // confirm that it is still listed as a server + Expect(lb.ServerAddresses()).To(ConsistOf(node1Server.address, node2Server.address, defaultServer.address)) + + // and is listed as the default + Expect(lb.servers.getDefaultAddress()).To(Equal(defaultServer.address), "default server is not default") + }) + }) + + // confirms that the loadbalancer will not dial itself + When("the default server is the loadbalancer", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var defaultServer *testServer + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + + defaultServer, err = createServer(ctx, "default") + Expect(err).NotTo(HaveOccurred(), "createServer(default) failed") + address := defaultServer.address + defaultServer.close() + _, port, _ := net.SplitHostPort(address) + intPort, _ := strconv.Atoi(port) + + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://"+address, intPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails immediately", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + _, err = ping(conn) + Expect(err).To(HaveOccurred(), "Unexpected successful ping on failed connection") + }) + }) + + // confirms that connnections to invalid addresses fail quickly + When("there are no valid addresses", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://127.0.0.1:0/", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails fast", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(10 * time.Millisecond) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from invalid address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) + + // confirms that connnections to unreachable addresses do fail within the + // expected duration + When("the server is unreachable", Ordered, func() { + ctx, cancel := context.WithCancel(context.Background()) + var lb *LoadBalancer + var err error + + BeforeAll(func() { + tmpDir := GinkgoT().TempDir() + lb, err = New(ctx, tmpDir, SupervisorServiceName, "http://192.0.2.1:6443", RandomPort, false) + Expect(err).NotTo(HaveOccurred(), "New() failed") + }) + + AfterAll(func() { + cancel() + }) + + It("fails with the correct timeout", func() { + conn, err := net.Dial("tcp", lb.localAddress) + Expect(err).NotTo(HaveOccurred(), "net.Dial failed") + + done := make(chan error) + go func() { + _, err = ping(conn) + done <- err + }() + timeout := time.After(11 * time.Second) + + select { + case err := <-done: + if err == nil { + Fail("Unexpected successful ping from unreachable address") + } + case <-timeout: + Fail("Test timed out") + } + }) + }) +}) diff --git a/pkg/agent/loadbalancer/metrics.go b/pkg/agent/loadbalancer/metrics.go new file mode 100644 index 000000000000..11f27486eda7 --- /dev/null +++ b/pkg/agent/loadbalancer/metrics.go @@ -0,0 +1,30 @@ +package loadbalancer + +import ( + "github.com/k3s-io/k3s/pkg/version" + "github.com/prometheus/client_golang/prometheus" + "k8s.io/component-base/metrics" +) + +var ( + loadbalancerConnections = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: version.Program + "_loadbalancer_server_connections", + Help: "Count of current connections to loadbalancer server", + }, []string{"name", "server"}) + + loadbalancerState = prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: version.Program + "_loadbalancer_server_health", + Help: "Current health value of loadbalancer server", + }, []string{"name", "server"}) + + loadbalancerDials = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: version.Program + "_loadbalancer_dial_duration_seconds", + Help: "Time taken to dial a connection to a backend server", + Buckets: metrics.ExponentialBuckets(0.001, 2, 15), + }, []string{"name", "status"}) +) + +// MustRegister registers loadbalancer metrics +func MustRegister(registerer prometheus.Registerer) { + registerer.MustRegister(loadbalancerConnections, loadbalancerState, loadbalancerDials) +} diff --git a/pkg/agent/loadbalancer/servers.go b/pkg/agent/loadbalancer/servers.go index 660810525470..13334ea881dc 100644 --- a/pkg/agent/loadbalancer/servers.go +++ b/pkg/agent/loadbalancer/servers.go @@ -1,166 +1,424 @@ package loadbalancer import ( + "cmp" "context" + "errors" "fmt" - "math/rand" "net" - "net/url" - "os" "slices" - "strconv" + "sync" "time" - "github.com/k3s-io/k3s/pkg/version" - http_dialer "github.com/mwitkow/go-http-dialer" - "github.com/pkg/errors" - "golang.org/x/net/http/httpproxy" - "golang.org/x/net/proxy" - "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" ) -var defaultDialer proxy.Dialer = &net.Dialer{ - Timeout: 10 * time.Second, - KeepAlive: 30 * time.Second, +type HealthCheckFunc func() HealthCheckResult + +// HealthCheckResult indicates the status of a server health check poll. +// For health-checks that poll in the background, Unknown should be returned +// if a poll has not occurred since the last check. +type HealthCheckResult int + +const ( + HealthCheckResultUnknown HealthCheckResult = iota + HealthCheckResultFailed + HealthCheckResultOK +) + +// serverList tracks potential backend servers for use by a loadbalancer. +type serverList struct { + // This mutex protects access to the server list. All direct access to the list should be protected by it. + mutex sync.Mutex + servers []*server } -// SetHTTPProxy configures a proxy-enabled dialer to be used for all loadbalancer connections, -// if the agent has been configured to allow use of a HTTP proxy, and the environment has been configured -// to indicate use of a HTTP proxy for the server URL. -func SetHTTPProxy(address string) error { - // Check if env variable for proxy is set - if useProxy, _ := strconv.ParseBool(os.Getenv(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED")); !useProxy || address == "" { - return nil +// setServers updates the server list to contain only the selected addresses. +func (sl *serverList) setAddresses(serviceName string, addresses []string) bool { + newAddresses := sets.New(addresses...) + curAddresses := sets.New(sl.getAddresses()...) + if newAddresses.Equal(curAddresses) { + return false } - serverURL, err := url.Parse(address) - if err != nil { - return errors.Wrapf(err, "failed to parse address %s", address) + sl.mutex.Lock() + defer sl.mutex.Unlock() + + var closeAllFuncs []func() + var defaultServer *server + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + defaultServer = sl.servers[i] } - // Call this directly instead of using the cached environment used by http.ProxyFromEnvironment to allow for testing - proxyFromEnvironment := httpproxy.FromEnvironment().ProxyFunc() - proxyURL, err := proxyFromEnvironment(serverURL) - if err != nil { - return errors.Wrapf(err, "failed to get proxy for address %s", address) + // add new servers + for addedAddress := range newAddresses.Difference(curAddresses) { + if defaultServer != nil && defaultServer.address == addedAddress { + // make default server go through the same health check promotions as a new server when added + logrus.Infof("Server %s->%s from add to load balancer %s", defaultServer, stateUnchecked, serviceName) + defaultServer.state = stateUnchecked + defaultServer.lastTransition = time.Now() + } else { + s := newServer(addedAddress, false) + logrus.Infof("Adding server to load balancer %s: %s", serviceName, s.address) + sl.servers = append(sl.servers, s) + } } - if proxyURL == nil { - logrus.Debug(version.ProgramUpper + "_AGENT_HTTP_PROXY_ALLOWED is true but no proxy is configured for URL " + serverURL.String()) - return nil + + // remove old servers + for removedAddress := range curAddresses.Difference(newAddresses) { + if defaultServer != nil && defaultServer.address == removedAddress { + // demote the default server down to standby, instead of deleting it + defaultServer.state = stateStandby + closeAllFuncs = append(closeAllFuncs, defaultServer.closeAll) + } else { + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.address == removedAddress { + logrus.Infof("Removing server from load balancer %s: %s", serviceName, s.address) + // set state to invalid to prevent server from making additional connections + s.state = stateInvalid + closeAllFuncs = append(closeAllFuncs, s.closeAll) + // remove metrics + loadbalancerState.DeleteLabelValues(serviceName, s.address) + loadbalancerConnections.DeleteLabelValues(serviceName, s.address) + return true + } + return false + }) + } } - dialer, err := proxyDialer(proxyURL, defaultDialer) - if err != nil { - return errors.Wrapf(err, "failed to create proxy dialer for %s", proxyURL) + slices.SortFunc(sl.servers, compareServers) + + // Close all connections to servers that were removed + for _, closeAll := range closeAllFuncs { + closeAll() } - defaultDialer = dialer - logrus.Debugf("Using proxy %s for agent connection to %s", proxyURL, serverURL) - return nil + return true +} + +// getAddresses returns the addresses of all servers. +// If the default server is in standby state, indicating it is only present +// because it is the default, it is not returned in this list. +func (sl *serverList) getAddresses() []string { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + addresses := make([]string, 0, len(sl.servers)) + for _, s := range sl.servers { + if s.isDefault && s.state == stateStandby { + continue + } + addresses = append(addresses, s.address) + } + return addresses } -func (lb *LoadBalancer) setServers(serverAddresses []string) bool { - serverAddresses, hasDefaultServer := sortServers(serverAddresses, lb.defaultServerAddress) - if len(serverAddresses) == 0 { +// setDefault sets the server with the provided address as the default server. +// The default flag is cleared on all other servers, and if the server was previously +// only kept in the list because it was the default, it is removed. +func (sl *serverList) setDefaultAddress(serviceName, address string) { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // deal with existing default first + sl.servers = slices.DeleteFunc(sl.servers, func(s *server) bool { + if s.isDefault && s.address != address { + s.isDefault = false + if s.state == stateStandby { + s.state = stateInvalid + defer s.closeAll() + return true + } + } return false + }) + + // update or create server with selected address + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + sl.servers[i].isDefault = true + } else { + sl.servers = append(sl.servers, newServer(address, true)) } - lb.mutex.Lock() - defer lb.mutex.Unlock() + logrus.Infof("Updated load balancer %s default server: %s", serviceName, address) + slices.SortFunc(sl.servers, compareServers) +} - newAddresses := sets.NewString(serverAddresses...) - curAddresses := sets.NewString(lb.ServerAddresses...) - if newAddresses.Equal(curAddresses) { - return false +// getDefault returns the address of the default server. +func (sl *serverList) getDefaultAddress() string { + if s := sl.getDefaultServer(); s != nil { + return s.address } + return "" +} - for addedServer := range newAddresses.Difference(curAddresses) { - logrus.Infof("Adding server to load balancer %s: %s", lb.serviceName, addedServer) - lb.servers[addedServer] = &server{ - address: addedServer, - connections: make(map[net.Conn]struct{}), - healthCheck: func() bool { return true }, - } +// getDefault returns the default server. +func (sl *serverList) getDefaultServer() *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.isDefault }); i != -1 { + return sl.servers[i] } + return nil +} + +// getServers returns a copy of the servers list that can be safely iterated over without holding a lock +func (sl *serverList) getServers() []*server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + return slices.Clone(sl.servers) +} - for removedServer := range curAddresses.Difference(newAddresses) { - server := lb.servers[removedServer] - if server != nil { - logrus.Infof("Removing server from load balancer %s: %s", lb.serviceName, removedServer) - // Defer closing connections until after the new server list has been put into place. - // Closing open connections ensures that anything stuck retrying on a stale server is forced - // over to a valid endpoint. - defer server.closeAll() - // Don't delete the default server from the server map, in case we need to fall back to it. - if removedServer != lb.defaultServerAddress { - delete(lb.servers, removedServer) +// getServer returns the first server with the specified address +func (sl *serverList) getServer(address string) *server { + sl.mutex.Lock() + defer sl.mutex.Unlock() + + if i := slices.IndexFunc(sl.servers, func(s *server) bool { return s.address == address }); i != -1 { + return sl.servers[i] + } + return nil +} + +// setHealthCheck updates the health check function for a server, replacing the +// current function. +func (sl *serverList) setHealthCheck(address string, healthCheck HealthCheckFunc) error { + if s := sl.getServer(address); s != nil { + s.healthCheck = healthCheck + return nil + } + return fmt.Errorf("no server found for %s", address) +} + +// recordSuccess records a successful check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordSuccess(srv *server, r reason) { + var new_state state + switch srv.state { + case stateFailed, stateUnchecked: + // dialed or health checked OK once, improve to recovering + new_state = stateRecovering + case stateRecovering: + if r == reasonHealthCheck { + // was recovering due to successful dial or first health check, can now improve + if len(srv.connections) > 0 { + // server accepted connections while recovering, attempt to go straight to active + new_state = stateActive + } else { + // no connections, just make it preferred + new_state = statePreferred + } + } + case stateHealthy: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } + case statePreferred: + if r == reasonDial { + // improve from healthy to active by being dialed + new_state = stateActive + } else { + if time.Now().Sub(srv.lastTransition) > time.Minute { + // has been preferred for a while without being dialed, demote to healthy + new_state = stateHealthy } } } - lb.ServerAddresses = serverAddresses - lb.randomServers = append([]string{}, lb.ServerAddresses...) - rand.Shuffle(len(lb.randomServers), func(i, j int) { - lb.randomServers[i], lb.randomServers[j] = lb.randomServers[j], lb.randomServers[i] - }) - // If the current server list does not contain the default server address, - // we want to include it in the random server list so that it can be tried if necessary. - // However, it should be treated as always failing health checks so that it is only - // used if all other endpoints are unavailable. - if !hasDefaultServer { - lb.randomServers = append(lb.randomServers, lb.defaultServerAddress) - if defaultServer, ok := lb.servers[lb.defaultServerAddress]; ok { - defaultServer.healthCheck = func() bool { return false } - lb.servers[lb.defaultServerAddress] = defaultServer + // no-op if state did not change + if new_state == stateInvalid { + return + } + + // handle active transition and sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // handle states of other servers when attempting to make this one active + if new_state == stateActive { + for _, s := range sl.servers { + if srv.address == s.address { + continue + } + switch s.state { + case stateFailed, stateStandby, stateRecovering, stateHealthy: + // close connections to other non-active servers whenever we have a new active server + defer s.closeAll() + case stateActive: + if len(s.connections) > len(srv.connections) { + // if there is a currently active server that has more connections than we do, + // close our connections and go to preferred instead + new_state = statePreferred + defer srv.closeAll() + } else { + // otherwise, close its connections and demote it to preferred + s.state = statePreferred + defer s.closeAll() + } + } } } - lb.currentServerAddress = lb.randomServers[0] - lb.nextServerIndex = 1 - return true + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return + } + + logrus.Infof("Server %s->%s from successful %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) } -// nextServer attempts to get the next server in the loadbalancer server list. -// If another goroutine has already updated the current server address to point at -// a different address than just failed, nothing is changed. Otherwise, a new server address -// is stored to the currentServerAddress field, and returned for use. -// This function must always be called by a goroutine that holds a read lock on the loadbalancer mutex. -func (lb *LoadBalancer) nextServer(failedServer string) (string, error) { - // note: these fields are not protected by the mutex, so we clamp the index value and update - // the index/current address using local variables, to avoid time-of-check vs time-of-use - // race conditions caused by goroutine A incrementing it in between the time goroutine B - // validates its value, and uses it as a list index. - currentServerAddress := lb.currentServerAddress - nextServerIndex := lb.nextServerIndex - - if len(lb.randomServers) == 0 { - return "", errors.New("No servers in load balancer proxy list") +// recordFailure records a failed check of a server, either via health-check or dial. +// The server's state is adjusted accordingly. +func (sl *serverList) recordFailure(srv *server, r reason) { + var new_state state + switch srv.state { + case stateUnchecked, stateRecovering: + if r == reasonDial { + // only demote from unchecked or recovering if a dial fails, health checks may + // continue to fail despite it being dialable. just leave it where it is + // and don't close any connections. + new_state = stateFailed + } + case stateHealthy, statePreferred, stateActive: + // should not have any connections when in any state other than active or + // recovering, but close them all anyway to force failover. + defer srv.closeAll() + new_state = stateFailed } - if len(lb.randomServers) == 1 { - return currentServerAddress, nil + + // no-op if state did not change + if new_state == stateInvalid { + return } - if failedServer != currentServerAddress { - return currentServerAddress, nil + + // sort the server list while holding the lock + sl.mutex.Lock() + defer sl.mutex.Unlock() + + // ensure some other routine didn't already make the transition + if srv.state == new_state { + return } - if nextServerIndex >= len(lb.randomServers) { - nextServerIndex = 0 + + logrus.Infof("Server %s->%s from failed %s", srv, new_state, r) + srv.state = new_state + srv.lastTransition = time.Now() + + slices.SortFunc(sl.servers, compareServers) +} + +// state is possible server health states, in increasing order of preference. +// The server list is kept sorted in descending order by this state value. +type state int + +const ( + stateInvalid state = iota + stateFailed // failed a health check or dial + stateStandby // reserved for use by default server if not in server list + stateUnchecked // just added, has not been health checked + stateRecovering // successfully health checked once, or dialed when failed + stateHealthy // normal state + statePreferred // recently transitioned from recovering; should be preferred as others may go down for maintenance + stateActive // currently active server +) + +func (s state) String() string { + switch s { + case stateInvalid: + return "INVALID" + case stateFailed: + return "FAILED" + case stateStandby: + return "STANDBY" + case stateUnchecked: + return "UNCHECKED" + case stateRecovering: + return "RECOVERING" + case stateHealthy: + return "HEALTHY" + case statePreferred: + return "PREFERRED" + case stateActive: + return "ACTIVE" + default: + return "UNKNOWN" } +} - currentServerAddress = lb.randomServers[nextServerIndex] - nextServerIndex++ +// reason specifies the reason for a successful or failed health report +type reason int - lb.currentServerAddress = currentServerAddress - lb.nextServerIndex = nextServerIndex +const ( + reasonDial reason = iota + reasonHealthCheck +) + +func (r reason) String() string { + switch r { + case reasonDial: + return "dial" + case reasonHealthCheck: + return "health check" + default: + return "unknown reason" + } +} + +// server tracks the connections to a server, so that they can be closed when the server is removed. +type server struct { + // This mutex protects access to the connections map. All direct access to the map should be protected by it. + mutex sync.Mutex + address string + isDefault bool + state state + lastTransition time.Time + healthCheck HealthCheckFunc + connections map[net.Conn]struct{} +} + +// newServer creates a new server, with a default health check +// and default/state fields appropriate for whether or not +// the server is a full server, or just a fallback default. +func newServer(address string, isDefault bool) *server { + state := stateUnchecked + if isDefault { + state = stateStandby + } + return &server{ + address: address, + isDefault: isDefault, + state: state, + lastTransition: time.Now(), + healthCheck: func() HealthCheckResult { return HealthCheckResultUnknown }, + connections: make(map[net.Conn]struct{}), + } +} - return currentServerAddress, nil +func (s *server) String() string { + format := "%s@%s" + if s.isDefault { + format += "*" + } + return fmt.Sprintf(format, s.address, s.state) } -// dialContext dials a new connection using the environment's proxy settings, and adds its wrapped connection to the map -func (s *server) dialContext(ctx context.Context, network, address string) (net.Conn, error) { - conn, err := defaultDialer.Dial(network, address) +// dialContext dials a new connection to the server using the environment's proxy settings, and adds its wrapped connection to the map +func (s *server) dialContext(ctx context.Context, network string) (net.Conn, error) { + if s.state == stateInvalid { + return nil, fmt.Errorf("server %s is stopping", s.address) + } + + conn, err := defaultDialer.Dial(network, s.address) if err != nil { return nil, err } @@ -174,26 +432,13 @@ func (s *server) dialContext(ctx context.Context, network, address string) (net. return wrappedConn, nil } -// proxyDialer creates a new proxy.Dialer that routes connections through the specified proxy. -func proxyDialer(proxyURL *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { - if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Create a new HTTP proxy dialer - httpProxyDialer := http_dialer.New(proxyURL, http_dialer.WithDialer(forward.(*net.Dialer))) - return httpProxyDialer, nil - } else if proxyURL.Scheme == "socks5" { - // For SOCKS5 proxies, use the proxy package's FromURL - return proxy.FromURL(proxyURL, forward) - } - return nil, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme) -} - // closeAll closes all connections to the server, and removes their entries from the map func (s *server) closeAll() { s.mutex.Lock() defer s.mutex.Unlock() if l := len(s.connections); l > 0 { - logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s.address) + logrus.Infof("Closing %d connections to load balancer server %s", len(s.connections), s) for conn := range s.connections { // Close the connection in a goroutine so that we don't hold the lock while doing so. go conn.Close() @@ -201,6 +446,12 @@ func (s *server) closeAll() { } } +// serverConn wraps a net.Conn so that it can be removed from the server's connection map when closed. +type serverConn struct { + server *server + net.Conn +} + // Close removes the connection entry from the server's connection map, and // closes the wrapped connection. func (sc *serverConn) Close() error { @@ -211,73 +462,47 @@ func (sc *serverConn) Close() error { return sc.Conn.Close() } -// SetDefault sets the selected address as the default / fallback address -func (lb *LoadBalancer) SetDefault(serverAddress string) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) - // if the old default server is not currently in use, remove it from the server map - if server := lb.servers[lb.defaultServerAddress]; server != nil && !hasDefaultServer { - defer server.closeAll() - delete(lb.servers, lb.defaultServerAddress) - } - // if the new default server doesn't have an entry in the map, add one - but - // with a failing health check so that it is only used as a last resort. - if _, ok := lb.servers[serverAddress]; !ok { - lb.servers[serverAddress] = &server{ - address: serverAddress, - healthCheck: func() bool { return false }, - connections: make(map[net.Conn]struct{}), +// runHealthChecks periodically health-checks all servers and updates metrics +func (sl *serverList) runHealthChecks(ctx context.Context, serviceName string) { + wait.Until(func() { + for _, s := range sl.getServers() { + switch s.healthCheck() { + case HealthCheckResultOK: + sl.recordSuccess(s, reasonHealthCheck) + case HealthCheckResultFailed: + sl.recordFailure(s, reasonHealthCheck) + } + if s.state != stateInvalid { + loadbalancerState.WithLabelValues(serviceName, s.address).Set(float64(s.state)) + loadbalancerConnections.WithLabelValues(serviceName, s.address).Set(float64(len(s.connections))) + } } - } - - lb.defaultServerAddress = serverAddress - logrus.Infof("Updated load balancer %s default server address -> %s", lb.serviceName, serverAddress) + }, time.Second, ctx.Done()) + logrus.Debugf("Stopped health checking for load balancer %s", serviceName) } -// SetHealthCheck adds a health-check callback to an address, replacing the default no-op function. -func (lb *LoadBalancer) SetHealthCheck(address string, healthCheck func() bool) { - lb.mutex.Lock() - defer lb.mutex.Unlock() - - if server := lb.servers[address]; server != nil { - logrus.Debugf("Added health check for load balancer %s: %s", lb.serviceName, address) - server.healthCheck = healthCheck - } else { - logrus.Errorf("Failed to add health check for load balancer %s: no server found for %s", lb.serviceName, address) +// dialContext attemps to dial a connection to a server from the server list. +// Success or failure is recorded to ensure that server state is updated appropriately. +func (sl *serverList) dialContext(ctx context.Context, network, _ string) (net.Conn, error) { + for _, s := range sl.getServers() { + dialTime := time.Now() + conn, err := s.dialContext(ctx, network) + if err == nil { + sl.recordSuccess(s, reasonDial) + return conn, nil + } + logrus.Debugf("Dial error from server %s after %s: %s", s, time.Now().Sub(dialTime), err) + sl.recordFailure(s, reasonDial) } + return nil, errors.New("all servers failed") } -// runHealthChecks periodically health-checks all servers. Any servers that fail the health-check will have their -// connections closed, to force clients to switch over to a healthy server. -func (lb *LoadBalancer) runHealthChecks(ctx context.Context) { - previousStatus := map[string]bool{} - wait.Until(func() { - lb.mutex.RLock() - defer lb.mutex.RUnlock() - var healthyServerExists bool - for address, server := range lb.servers { - status := server.healthCheck() - healthyServerExists = healthyServerExists || status - if status == false && previousStatus[address] == true { - // Only close connections when the server transitions from healthy to unhealthy; - // we don't want to re-close all the connections every time as we might be ignoring - // health checks due to all servers being marked unhealthy. - defer server.closeAll() - } - previousStatus[address] = status - } - - // If there is at least one healthy server, and the default server is not in the server list, - // close all the connections to the default server so that clients reconnect and switch over - // to a preferred server. - hasDefaultServer := slices.Contains(lb.ServerAddresses, lb.defaultServerAddress) - if healthyServerExists && !hasDefaultServer { - if server, ok := lb.servers[lb.defaultServerAddress]; ok { - defer server.closeAll() - } - } - }, time.Second, ctx.Done()) - logrus.Debugf("Stopped health checking for load balancer %s", lb.serviceName) +// compareServers is a comparison function that can be used to sort the server list +// so that servers with a more preferred state, or higher number of connections, are ordered first. +func compareServers(a, b *server) int { + c := cmp.Compare(b.state, a.state) + if c == 0 { + return cmp.Compare(len(b.connections), len(a.connections)) + } + return c } diff --git a/pkg/agent/netpol/netpol.go b/pkg/agent/netpol/netpol.go index 5c892a668f36..a9f7a43f532e 100644 --- a/pkg/agent/netpol/netpol.go +++ b/pkg/agent/netpol/netpol.go @@ -26,12 +26,12 @@ import ( "github.com/coreos/go-iptables/iptables" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/metrics" + "github.com/k3s-io/k3s/pkg/util" "github.com/pkg/errors" "github.com/sirupsen/logrus" v1core "k8s.io/api/core/v1" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" ) func init() { @@ -57,7 +57,7 @@ func Run(ctx context.Context, nodeConfig *config.Node) error { return nil } - restConfig, err := clientcmd.BuildConfigFromFlags("", nodeConfig.AgentConfig.KubeConfigK3sController) + restConfig, err := util.GetRESTConfig(nodeConfig.AgentConfig.KubeConfigK3sController) if err != nil { return err } diff --git a/pkg/agent/proxy/apiproxy.go b/pkg/agent/proxy/apiproxy.go index e711623e467e..56d86a031366 100644 --- a/pkg/agent/proxy/apiproxy.go +++ b/pkg/agent/proxy/apiproxy.go @@ -22,7 +22,7 @@ type Proxy interface { SupervisorAddresses() []string APIServerURL() string IsAPIServerLBEnabled() bool - SetHealthCheck(address string, healthCheck func() bool) + SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) } // NewSupervisorProxy sets up a new proxy for retrieving supervisor and apiserver addresses. If @@ -52,7 +52,7 @@ func NewSupervisorProxy(ctx context.Context, lbEnabled bool, dataDir, supervisor return nil, err } p.supervisorLB = lb - p.supervisorURL = lb.LoadBalancerServerURL() + p.supervisorURL = lb.LocalURL() p.apiServerURL = p.supervisorURL } @@ -102,7 +102,7 @@ func (p *proxy) Update(addresses []string) { p.supervisorAddresses = supervisorAddresses } -func (p *proxy) SetHealthCheck(address string, healthCheck func() bool) { +func (p *proxy) SetHealthCheck(address string, healthCheck loadbalancer.HealthCheckFunc) { if p.supervisorLB != nil { p.supervisorLB.SetHealthCheck(address, healthCheck) } @@ -155,7 +155,7 @@ func (p *proxy) SetAPIServerPort(port int, isIPv6 bool) error { return err } p.apiServerLB = lb - p.apiServerURL = lb.LoadBalancerServerURL() + p.apiServerURL = lb.LocalURL() } else { p.apiServerURL = u.String() } diff --git a/pkg/agent/tunnel/tunnel.go b/pkg/agent/tunnel/tunnel.go index a5df415c7343..98094a8d02dc 100644 --- a/pkg/agent/tunnel/tunnel.go +++ b/pkg/agent/tunnel/tunnel.go @@ -7,14 +7,15 @@ import ( "fmt" "net" "os" - "reflect" "strconv" "sync" "time" "github.com/gorilla/websocket" agentconfig "github.com/k3s-io/k3s/pkg/agent/config" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/agent/proxy" + "github.com/k3s-io/k3s/pkg/clientaccess" daemonconfig "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/k3s-io/k3s/pkg/util" "github.com/k3s-io/k3s/pkg/version" @@ -26,12 +27,12 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" "k8s.io/client-go/tools/cache" - "k8s.io/client-go/tools/clientcmd" toolswatch "k8s.io/client-go/tools/watch" "k8s.io/kubernetes/pkg/cluster/ports" ) @@ -69,7 +70,7 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er return err } - nodeRestConfig, err := clientcmd.BuildConfigFromFlags("", config.AgentConfig.KubeConfigKubelet) + nodeRestConfig, err := util.GetRESTConfig(config.AgentConfig.KubeConfigKubelet) if err != nil { return err } @@ -138,17 +139,18 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er // connecting to. If that fails, fall back to querying the endpoints list from Kubernetes. This // fallback requires that the server we're joining be running an apiserver, but is the only safe // thing to do if its supervisor is down-level and can't provide us with an endpoint list. - addresses := agentconfig.APIServers(ctx, config, proxy) - logrus.Infof("Got apiserver addresses from supervisor: %v", addresses) - + addresses := agentconfig.WaitForAPIServers(ctx, config, proxy) if len(addresses) > 0 { + logrus.Infof("Got apiserver addresses from supervisor: %v", addresses) if localSupervisorDefault { proxy.SetSupervisorDefault(addresses[0]) } proxy.Update(addresses) } else { - if endpoint, _ := client.CoreV1().Endpoints(metav1.NamespaceDefault).Get(ctx, "kubernetes", metav1.GetOptions{}); endpoint != nil { - addresses = util.GetAddresses(endpoint) + if endpoint, err := client.CoreV1().Endpoints(metav1.NamespaceDefault).Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { + logrus.Errorf("Failed to get apiserver addresses from kubernetes endpoints: %v", err) + } else { + addresses := util.GetAddresses(endpoint) logrus.Infof("Got apiserver addresses from kubernetes endpoints: %v", addresses) if len(addresses) > 0 { proxy.Update(addresses) @@ -159,7 +161,7 @@ func Setup(ctx context.Context, config *daemonconfig.Node, proxy proxy.Proxy) er wg := &sync.WaitGroup{} - go tunnel.watchEndpoints(ctx, apiServerReady, wg, tlsConfig, proxy) + go tunnel.watchEndpoints(ctx, apiServerReady, wg, tlsConfig, config, proxy) wait := make(chan int, 1) go func() { @@ -302,23 +304,21 @@ func (a *agentTunnel) watchPods(ctx context.Context, apiServerReady <-chan struc // WatchEndpoints attempts to create tunnels to all supervisor addresses. Once the // apiserver is up, go into a watch loop, adding and removing tunnels as endpoints come // and go from the cluster. -func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan struct{}, wg *sync.WaitGroup, tlsConfig *tls.Config, proxy proxy.Proxy) { - // Attempt to connect to supervisors, storing their cancellation function for later when we - // need to disconnect. - disconnect := map[string]context.CancelFunc{} - for _, address := range proxy.SupervisorAddresses() { - if _, ok := disconnect[address]; !ok { - conn := a.connect(ctx, wg, address, tlsConfig) - disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) - } - } +func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan struct{}, wg *sync.WaitGroup, tlsConfig *tls.Config, node *daemonconfig.Node, proxy proxy.Proxy) { + syncProxyAddresses := a.getProxySyncer(ctx, wg, tlsConfig, proxy) + refreshFromSupervisor := getAPIServersRequester(node, proxy, syncProxyAddresses) <-apiServerReady + endpoints := a.client.CoreV1().Endpoints(metav1.NamespaceDefault) fieldSelector := fields.Set{metav1.ObjectNameField: "kubernetes"}.String() lw := &cache.ListWatch{ ListFunc: func(options metav1.ListOptions) (object runtime.Object, e error) { + // if we're being called to re-list, then likely there was an + // interruption to the apiserver connection and the listwatch is retrying + // its connection. This is a good suggestion that it might be necessary + // to refresh the apiserver address from the supervisor. + go refreshFromSupervisor(ctx) options.FieldSelector = fieldSelector return endpoints.List(ctx, options) }, @@ -364,38 +364,7 @@ func (a *agentTunnel) watchEndpoints(ctx context.Context, apiServerReady <-chan // goroutine that sleeps for a short period before checking for changes and updating // the proxy addresses. If another update occurs, the previous update operation // will be cancelled and a new one queued. - go func() { - select { - case <-time.After(endpointDebounceDelay): - case <-debounceCtx.Done(): - return - } - - newAddresses := util.GetAddresses(endpoint) - if reflect.DeepEqual(newAddresses, proxy.SupervisorAddresses()) { - return - } - proxy.Update(newAddresses) - - validEndpoint := map[string]bool{} - - for _, address := range proxy.SupervisorAddresses() { - validEndpoint[address] = true - if _, ok := disconnect[address]; !ok { - conn := a.connect(ctx, nil, address, tlsConfig) - disconnect[address] = conn.cancel - proxy.SetHealthCheck(address, conn.connected) - } - } - - for address, cancel := range disconnect { - if !validEndpoint[address] { - cancel() - delete(disconnect, address) - logrus.Infof("Stopped tunnel to %s", address) - } - } - }() + go syncProxyAddresses(debounceCtx, util.GetAddresses(endpoint)) } } } @@ -427,22 +396,20 @@ func (a *agentTunnel) authorized(ctx context.Context, proto, address string) boo } type agentConnection struct { - cancel context.CancelFunc - connected func() bool + cancel context.CancelFunc + healthCheck loadbalancer.HealthCheckFunc } // connect initiates a connection to the remotedialer server. Incoming dial requests from // the server will be checked by the authorizer function prior to being fulfilled. func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup, address string, tlsConfig *tls.Config) agentConnection { + var status loadbalancer.HealthCheckResult + wsURL := fmt.Sprintf("wss://%s/v1-"+version.Program+"/connect", address) ws := &websocket.Dialer{ TLSClientConfig: tlsConfig, } - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true - once := sync.Once{} if waitGroup != nil { waitGroup.Add(1) @@ -454,7 +421,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup } onConnect := func(_ context.Context, _ *remotedialer.Session) error { - connected = true + status = loadbalancer.HealthCheckResultOK logrus.WithField("url", wsURL).Info("Remotedialer connected to proxy") if waitGroup != nil { once.Do(waitGroup.Done) @@ -467,7 +434,7 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup for { // ConnectToProxy blocks until error or context cancellation err := remotedialer.ConnectToProxyWithDialer(ctx, wsURL, nil, auth, ws, a.dialContext, onConnect) - connected = false + status = loadbalancer.HealthCheckResultFailed if err != nil && !errors.Is(err, context.Canceled) { logrus.WithField("url", wsURL).WithError(err).Error("Remotedialer proxy error; reconnecting...") // wait between reconnection attempts to avoid hammering the server @@ -484,8 +451,10 @@ func (a *agentTunnel) connect(rootCtx context.Context, waitGroup *sync.WaitGroup }() return agentConnection{ - cancel: cancel, - connected: func() bool { return connected }, + cancel: cancel, + healthCheck: func() loadbalancer.HealthCheckResult { + return status + }, } } @@ -507,3 +476,84 @@ func (a *agentTunnel) dialContext(ctx context.Context, network, address string) } return defaultDialer.DialContext(ctx, network, address) } + +// proxySyncer is a common signature for functions that sync the proxy address list with a context +type proxySyncer func(ctx context.Context, addresses []string) + +// getProxySyncer returns a function that can be called to update the list of supervisors. +// This function is responsible for connecting to or disconnecting websocket tunnels, +// as well as updating the proxy loadbalancer server list. +func (a *agentTunnel) getProxySyncer(ctx context.Context, wg *sync.WaitGroup, tlsConfig *tls.Config, proxy proxy.Proxy) proxySyncer { + disconnect := map[string]context.CancelFunc{} + // Attempt to connect to supervisors, storing their cancellation function for later when we + // need to disconnect. + for _, address := range proxy.SupervisorAddresses() { + if _, ok := disconnect[address]; !ok { + conn := a.connect(ctx, wg, address, tlsConfig) + disconnect[address] = conn.cancel + proxy.SetHealthCheck(address, conn.healthCheck) + } + } + + // return a function that can be called to update the address list. + // servers will be connected to or disconnected from as necessary, + // and the proxy addresses updated. + return func(debounceCtx context.Context, addresses []string) { + select { + case <-time.After(endpointDebounceDelay): + case <-debounceCtx.Done(): + return + } + + // Compare list of supervisor addresses before and after syncing apiserver + // endpoints into the proxy to figure out which supervisors we need to connect to + // or disconnect from. Note that the addresses we were passed will not match + // the supervisor addresses if the supervisor and apiserver are on different ports - + // they must be round-tripped through proxy.Update before comparing. + curAddresses := sets.New(proxy.SupervisorAddresses()...) + proxy.Update(addresses) + newAddresses := sets.New(proxy.SupervisorAddresses()...) + + // add new servers + for address := range newAddresses.Difference(curAddresses) { + if _, ok := disconnect[address]; !ok { + conn := a.connect(ctx, nil, address, tlsConfig) + logrus.Infof("Started tunnel to %s", address) + disconnect[address] = conn.cancel + proxy.SetHealthCheck(address, conn.healthCheck) + } + } + + // remove old servers + for address := range curAddresses.Difference(newAddresses) { + if cancel, ok := disconnect[address]; ok { + cancel() + delete(disconnect, address) + logrus.Infof("Stopped tunnel to %s", address) + } + } + } +} + +// getAPIServersRequester returns a function that can be called to update the +// proxy apiserver endpoints with addresses retrieved from the supervisor. +func getAPIServersRequester(node *daemonconfig.Node, proxy proxy.Proxy, syncProxyAddresses proxySyncer) func(ctx context.Context) { + var info *clientaccess.Info + return func(ctx context.Context) { + if info == nil { + var err error + withCert := clientaccess.WithClientCertificate(node.AgentConfig.ClientKubeletCert, node.AgentConfig.ClientKubeletKey) + info, err = clientaccess.ParseAndValidateToken(proxy.SupervisorURL(), node.Token, withCert) + if err != nil { + logrus.Warnf("Failed to validate server token: %v", err) + return + } + } + + if addresses, err := agentconfig.GetAPIServers(ctx, info); err != nil { + logrus.Warnf("Failed to get apiserver addresses from supervisor: %v", err) + } else { + syncProxyAddresses(ctx, addresses) + } + } +} diff --git a/pkg/cli/cmds/server.go b/pkg/cli/cmds/server.go index c398fc14c1bb..e2eee6ae022d 100644 --- a/pkg/cli/cmds/server.go +++ b/pkg/cli/cmds/server.go @@ -188,6 +188,27 @@ var ServerFlags = []cli.Flag{ Value: 6443, Destination: &ServerConfig.HTTPSPort, }, + &cli.IntFlag{ + Name: "supervisor-port", + EnvVar: version.ProgramUpper + "_SUPERVISOR_PORT", + Usage: "(experimental) Supervisor listen port override", + Hidden: true, + Destination: &ServerConfig.SupervisorPort, + }, + &cli.IntFlag{ + Name: "apiserver-port", + EnvVar: version.ProgramUpper + "_APISERVER_PORT", + Usage: "(experimental) apiserver internal listen port override", + Hidden: true, + Destination: &ServerConfig.APIServerPort, + }, + &cli.StringFlag{ + Name: "apiserver-bind-address", + EnvVar: version.ProgramUpper + "_APISERVER_BIND_ADDRESS", + Usage: "(experimental) apiserver internal bind address override", + Hidden: true, + Destination: &ServerConfig.APIServerBindAddress, + }, &cli.StringFlag{ Name: "advertise-address", Usage: "(listener) IPv4/IPv6 address that apiserver uses to advertise to members of the cluster (default: node-external-ip/node-ip)", @@ -195,7 +216,7 @@ var ServerFlags = []cli.Flag{ }, &cli.IntFlag{ Name: "advertise-port", - Usage: "(listener) Port that apiserver uses to advertise to members of the cluster (default: listen-port)", + Usage: "(listener) Port that apiserver uses to advertise to members of the cluster (default: https-listen-port)", Destination: &ServerConfig.AdvertisePort, }, &cli.StringSliceFlag{ diff --git a/pkg/cli/token/token.go b/pkg/cli/token/token.go index e16038fea5b6..9d514d7b5286 100644 --- a/pkg/cli/token/token.go +++ b/pkg/cli/token/token.go @@ -24,7 +24,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/fields" "k8s.io/apimachinery/pkg/util/duration" - "k8s.io/client-go/tools/clientcmd" bootstrapapi "k8s.io/cluster-bootstrap/token/api" bootstraputil "k8s.io/cluster-bootstrap/token/util" "k8s.io/utils/ptr" @@ -48,7 +47,7 @@ func create(app *cli.Context, cfg *cmds.Token) error { return err } - restConfig, err := clientcmd.BuildConfigFromFlags("", cfg.Kubeconfig) + restConfig, err := util.GetRESTConfig(cfg.Kubeconfig) if err != nil { return err } diff --git a/pkg/cluster/address_controller.go b/pkg/cluster/address_controller.go index 780942d0d3ae..bb73a20deac4 100644 --- a/pkg/cluster/address_controller.go +++ b/pkg/cluster/address_controller.go @@ -8,20 +8,17 @@ import ( controllerv1 "github.com/rancher/wrangler/v3/pkg/generated/controllers/core/v1" "github.com/sirupsen/logrus" v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/sets" ) func registerAddressHandlers(ctx context.Context, c *Cluster) { nodes := c.config.Runtime.Core.Core().V1().Node() a := &addressesHandler{ nodeController: nodes, - allowed: map[string]bool{}, + allowed: sets.New(c.config.SANs...), } - for _, cn := range c.config.SANs { - a.allowed[cn] = true - } - - logrus.Infof("Starting dynamiclistener CN filter node controller") + logrus.Infof("Starting dynamiclistener CN filter node controller with SANs: %v", c.config.SANs) nodes.OnChange(ctx, "server-cn-filter", a.sync) c.cnFilterFunc = a.filterCN } @@ -30,40 +27,30 @@ type addressesHandler struct { sync.RWMutex nodeController controllerv1.NodeController - allowed map[string]bool + allowed sets.Set[string] } // filterCN filters a list of potential server CNs (hostnames or IPs), removing any which do not correspond to // valid cluster servers (control-plane or etcd), or an address explicitly added via the tls-san option. func (a *addressesHandler) filterCN(cns ...string) []string { - if !a.nodeController.Informer().HasSynced() { + if len(cns) == 0 || !a.nodeController.Informer().HasSynced() { return cns } a.RLock() defer a.RUnlock() - filteredCNs := make([]string, 0, len(cns)) - for _, cn := range cns { - if a.allowed[cn] { - filteredCNs = append(filteredCNs, cn) - } else { - logrus.Debugf("CN filter controller rejecting certificate CN: %s", cn) - } - } - return filteredCNs + return a.allowed.Intersection(sets.New(cns...)).UnsortedList() } // sync updates the allowed address list to include addresses for the node func (a *addressesHandler) sync(key string, node *v1.Node) (*v1.Node, error) { - if node != nil { - if node.Labels[util.ControlPlaneRoleLabelKey] != "" || node.Labels[util.ETCDRoleLabelKey] != "" { - a.Lock() - defer a.Unlock() + if node != nil && (node.Labels[util.ControlPlaneRoleLabelKey] != "" || node.Labels[util.ETCDRoleLabelKey] != "") { + a.Lock() + defer a.Unlock() - for _, address := range node.Status.Addresses { - a.allowed[address.String()] = true - } + for _, address := range node.Status.Addresses { + a.allowed.Insert(address.String()) } } return node, nil diff --git a/pkg/daemons/config/types.go b/pkg/daemons/config/types.go index f6336a2ba251..ffb67d702e02 100644 --- a/pkg/daemons/config/types.go +++ b/pkg/daemons/config/types.go @@ -17,6 +17,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" utilnet "k8s.io/apimachinery/pkg/util/net" "k8s.io/apiserver/pkg/authentication/authenticator" + "k8s.io/client-go/kubernetes" "k8s.io/client-go/tools/record" utilsnet "k8s.io/utils/net" ) @@ -369,6 +370,7 @@ type ControlRuntime struct { ClientETCDCert string ClientETCDKey string + K8s kubernetes.Interface K3s *k3s.Factory Core *core.Factory Event record.EventRecorder diff --git a/pkg/daemons/executor/embed.go b/pkg/daemons/executor/embed.go index 0553da84e3e0..7e69f956e84b 100644 --- a/pkg/daemons/executor/embed.go +++ b/pkg/daemons/executor/embed.go @@ -28,7 +28,6 @@ import ( "k8s.io/apiserver/pkg/authentication/authenticator" typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/cache" - "k8s.io/client-go/tools/clientcmd" toolswatch "k8s.io/client-go/tools/watch" cloudprovider "k8s.io/cloud-provider" cloudproviderapi "k8s.io/cloud-provider/api" @@ -269,7 +268,7 @@ func (e *Embedded) Docker(ctx context.Context, cfg *daemonconfig.Node) error { // waitForUntaintedNode watches nodes, waiting to find one not tainted as // uninitialized by the external cloud provider. func waitForUntaintedNode(ctx context.Context, kubeConfig string) error { - restConfig, err := clientcmd.BuildConfigFromFlags("", kubeConfig) + restConfig, err := util.GetRESTConfig(kubeConfig) if err != nil { return err } diff --git a/pkg/etcd/etcdproxy.go b/pkg/etcd/etcdproxy.go index 55918850b3ff..156834440c08 100644 --- a/pkg/etcd/etcdproxy.go +++ b/pkg/etcd/etcdproxy.go @@ -6,21 +6,16 @@ import ( "fmt" "net" "net/http" - "net/url" "strconv" "time" "github.com/k3s-io/k3s/pkg/agent/loadbalancer" - "github.com/pkg/errors" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/util/wait" ) type Proxy interface { Update(addresses []string) - ETCDURL() string - ETCDAddresses() []string - ETCDServerURL() string } var httpClient = &http.Client{ @@ -34,51 +29,33 @@ var httpClient = &http.Client{ // NewETCDProxy initializes a new proxy structure that contain a load balancer // which listens on port 2379 and proxy between etcd cluster members func NewETCDProxy(ctx context.Context, supervisorPort int, dataDir, etcdURL string, isIPv6 bool) (Proxy, error) { - u, err := url.Parse(etcdURL) - if err != nil { - return nil, errors.Wrap(err, "failed to parse etcd client URL") - } - - e := &etcdproxy{ - dataDir: dataDir, - initialETCDURL: etcdURL, - etcdURL: etcdURL, - supervisorPort: supervisorPort, - disconnect: map[string]context.CancelFunc{}, - } - lb, err := loadbalancer.New(ctx, dataDir, loadbalancer.ETCDServerServiceName, etcdURL, 2379, isIPv6) if err != nil { return nil, err } - e.etcdLB = lb - e.etcdLBURL = lb.LoadBalancerServerURL() - - e.fallbackETCDAddress = u.Host - e.etcdPort = u.Port() - return e, nil + return &etcdproxy{ + supervisorPort: supervisorPort, + etcdLB: lb, + disconnect: map[string]context.CancelFunc{}, + }, nil } type etcdproxy struct { - dataDir string - etcdLBURL string - - supervisorPort int - initialETCDURL string - etcdURL string - etcdPort string - fallbackETCDAddress string - etcdAddresses []string - etcdLB *loadbalancer.LoadBalancer - disconnect map[string]context.CancelFunc + supervisorPort int + etcdLB *loadbalancer.LoadBalancer + disconnect map[string]context.CancelFunc } func (e *etcdproxy) Update(addresses []string) { + if e.etcdLB == nil { + return + } + e.etcdLB.Update(addresses) validEndpoint := map[string]bool{} - for _, address := range e.etcdLB.ServerAddresses { + for _, address := range e.etcdLB.ServerAddresses() { validEndpoint[address] = true if _, ok := e.disconnect[address]; !ok { ctx, cancel := context.WithCancel(context.Background()) @@ -95,27 +72,10 @@ func (e *etcdproxy) Update(addresses []string) { } } -func (e *etcdproxy) ETCDURL() string { - return e.etcdURL -} - -func (e *etcdproxy) ETCDAddresses() []string { - if len(e.etcdAddresses) > 0 { - return e.etcdAddresses - } - return []string{e.fallbackETCDAddress} -} - -func (e *etcdproxy) ETCDServerURL() string { - return e.etcdURL -} - // start a polling routine that makes periodic requests to the etcd node's supervisor port. // If the request fails, the node is marked unhealthy. -func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() bool { - // Assume that the connection to the server will succeed, to avoid failing health checks while attempting to connect. - // If we cannot connect, connected will be set to false when the initial connection attempt fails. - connected := true +func (e etcdproxy) createHealthCheck(ctx context.Context, address string) loadbalancer.HealthCheckFunc { + var status loadbalancer.HealthCheckResult host, _, _ := net.SplitHostPort(address) url := fmt.Sprintf("https://%s/ping", net.JoinHostPort(host, strconv.Itoa(e.supervisorPort))) @@ -131,13 +91,17 @@ func (e etcdproxy) createHealthCheck(ctx context.Context, address string) func() } if err != nil || statusCode != http.StatusOK { logrus.Debugf("Health check %s failed: %v (StatusCode: %d)", address, err, statusCode) - connected = false + status = loadbalancer.HealthCheckResultFailed } else { - connected = true + status = loadbalancer.HealthCheckResultOK } }, 5*time.Second, 1.0, true) - return func() bool { - return connected + return func() loadbalancer.HealthCheckResult { + // Reset the status to unknown on reading, until next time it is checked. + // This avoids having a health check result alter the server state between active checks. + s := status + status = loadbalancer.HealthCheckResultUnknown + return s } } diff --git a/pkg/metrics/metrics.go b/pkg/metrics/metrics.go index a769e6a38418..eccb4abb0bbc 100644 --- a/pkg/metrics/metrics.go +++ b/pkg/metrics/metrics.go @@ -6,6 +6,7 @@ import ( "github.com/gorilla/mux" "github.com/k3s-io/k3s/pkg/agent/https" + "github.com/k3s-io/k3s/pkg/agent/loadbalancer" "github.com/k3s-io/k3s/pkg/daemons/config" "github.com/prometheus/client_golang/prometheus/promhttp" lassometrics "github.com/rancher/lasso/pkg/metrics" @@ -32,6 +33,8 @@ var DefaultMetrics = &Config{ func init() { // ensure that lasso exposes metrics through the same registry used by Kubernetes components lassometrics.MustRegister(DefaultRegisterer) + // same for loadbalancer metrics + loadbalancer.MustRegister(DefaultRegisterer) } // Config holds fields for the metrics listener diff --git a/pkg/secretsencrypt/config.go b/pkg/secretsencrypt/config.go index aae309d8fbad..7d2f2e4a725b 100644 --- a/pkg/secretsencrypt/config.go +++ b/pkg/secretsencrypt/config.go @@ -15,7 +15,6 @@ import ( "github.com/k3s-io/k3s/pkg/version" "github.com/prometheus/common/expfmt" corev1 "k8s.io/api/core/v1" - "k8s.io/client-go/tools/clientcmd" "github.com/k3s-io/k3s/pkg/generated/clientset/versioned/scheme" "github.com/sirupsen/logrus" @@ -237,7 +236,7 @@ func GetEncryptionConfigMetrics(runtime *config.ControlRuntime, initialMetrics b var unixUpdateTime int64 var reloadSuccessCounter int64 var lastFailure string - restConfig, err := clientcmd.BuildConfigFromFlags("", runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(runtime.KubeConfigSupervisor) if err != nil { return 0, 0, err } diff --git a/pkg/server/context.go b/pkg/server/context.go index ac6724820ee3..fb4928e8f1ad 100644 --- a/pkg/server/context.go +++ b/pkg/server/context.go @@ -19,7 +19,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/record" ) @@ -43,7 +42,7 @@ func NewContext(ctx context.Context, config *Config, forServer bool) (*Context, if forServer { cfg = config.ControlConfig.Runtime.KubeConfigSupervisor } - restConfig, err := clientcmd.BuildConfigFromFlags("", cfg) + restConfig, err := util.GetRESTConfig(cfg) if err != nil { return nil, err } diff --git a/pkg/server/router.go b/pkg/server/router.go index ec60d5f3d9c9..fca554027880 100644 --- a/pkg/server/router.go +++ b/pkg/server/router.go @@ -19,6 +19,7 @@ import ( "github.com/k3s-io/k3s/pkg/bootstrap" "github.com/k3s-io/k3s/pkg/cli/cmds" "github.com/k3s-io/k3s/pkg/daemons/config" + "github.com/k3s-io/k3s/pkg/etcd" "github.com/k3s-io/k3s/pkg/nodepassword" "github.com/k3s-io/k3s/pkg/server/auth" "github.com/k3s-io/k3s/pkg/util" @@ -31,9 +32,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/json" + "k8s.io/apimachinery/pkg/util/sets" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/apiserver/pkg/authentication/user" "k8s.io/apiserver/pkg/endpoints/request" + typedcorev1 "k8s.io/client-go/kubernetes/typed/core/v1" bootstrapapi "k8s.io/cluster-bootstrap/token/api" "k8s.io/kubernetes/pkg/auth/nodeidentifier" ) @@ -304,21 +307,15 @@ func fileHandler(fileName ...string) http.Handler { }) } +// apiserversHandler returns a list of apiserver addresses. +// It attempts to merge results from both the apiserver and directly from etcd, +// in case we are recovering from an apiserver outage that rendered the endpoint list unavailable. func apiserversHandler(server *config.Control) http.Handler { - var endpointsClient coreclient.EndpointsClient + collectAddresses := getAddressCollector(server) return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { - var endpoints []string - if endpointsClient == nil { - if server.Runtime.Core != nil { - endpointsClient = server.Runtime.Core.Core().V1().Endpoints() - } - } - if endpointsClient != nil { - if endpoint, _ := endpointsClient.Get("default", "kubernetes", metav1.GetOptions{}); endpoint != nil { - endpoints = util.GetAddresses(endpoint) - } - } - + ctx, cancel := context.WithTimeout(req.Context(), 5*time.Second) + defer cancel() + endpoints := collectAddresses(ctx) resp.Header().Set("content-type", "application/json") if err := json.NewEncoder(resp).Encode(endpoints); err != nil { util.SendError(errors.Wrap(err, "failed to encode apiserver endpoints"), resp, req, http.StatusInternalServerError) @@ -524,3 +521,75 @@ func ensureSecret(ctx context.Context, config *Config, node *nodeInfo) { return false, nil }) } + +// addressGetter is a common signature for functions that return an address channel +type addressGetter func(ctx context.Context) <-chan []string + +// kubernetesGetter returns a function that returns a channel that can be read to get apiserver addresses from kubernetes endpoints +func kubernetesGetter(server *config.Control) addressGetter { + var endpointsClient typedcorev1.EndpointsInterface + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if endpointsClient == nil { + if server.Runtime.K8s != nil { + endpointsClient = server.Runtime.K8s.CoreV1().Endpoints(metav1.NamespaceDefault) + } + } + if endpointsClient != nil { + if endpoint, err := endpointsClient.Get(ctx, "kubernetes", metav1.GetOptions{}); err != nil { + logrus.Debugf("Failed to get apiserver addresses from kubernetes: %v", err) + } else { + ch <- util.GetAddresses(endpoint) + } + } + close(ch) + }() + return ch + } +} + +// etcdGetter returns a function that returns a channel that can be read to get apiserver addresses from etcd +func etcdGetter(server *config.Control) addressGetter { + return func(ctx context.Context) <-chan []string { + ch := make(chan []string, 1) + go func() { + if addresses, err := etcd.GetAPIServerURLsFromETCD(ctx, server); err != nil { + logrus.Debugf("Failed to get apiserver addresses from etcd: %v", err) + } else { + ch <- addresses + } + close(ch) + }() + return ch + } +} + +// getAddressCollector returns a function that can be called to return +// apiserver addresses from both kubernetes and etcd +func getAddressCollector(server *config.Control) func(ctx context.Context) []string { + getFromKubernetes := kubernetesGetter(server) + getFromEtcd := etcdGetter(server) + + // read from both kubernetes and etcd in parallel, returning the collected results + return func(ctx context.Context) []string { + a := sets.Set[string]{} + r := []string{} + k8sCh := getFromKubernetes(ctx) + k8sOk := true + etcdCh := getFromEtcd(ctx) + etcdOk := true + + for k8sOk || etcdOk { + select { + case r, k8sOk = <-k8sCh: + a.Insert(r...) + case r, etcdOk = <-etcdCh: + a.Insert(r...) + case <-ctx.Done(): + return a.UnsortedList() + } + } + return a.UnsortedList() + } +} diff --git a/pkg/server/secrets-encrypt.go b/pkg/server/secrets-encrypt.go index 256c98ce1003..a3759d9617c4 100644 --- a/pkg/server/secrets-encrypt.go +++ b/pkg/server/secrets-encrypt.go @@ -27,8 +27,6 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" apiserverconfigv1 "k8s.io/apiserver/pkg/apis/apiserver/v1" - "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/pager" "k8s.io/client-go/util/retry" "k8s.io/utils/ptr" @@ -395,18 +393,7 @@ func reencryptAndRemoveKey(ctx context.Context, server *config.Control, skip boo } func updateSecrets(ctx context.Context, server *config.Control, nodeName string) error { - restConfig, err := clientcmd.BuildConfigFromFlags("", server.Runtime.KubeConfigSupervisor) - if err != nil { - return err - } - // For secrets we need a much higher QPS than default - restConfig.QPS = secretsencrypt.SecretQPS - restConfig.Burst = secretsencrypt.SecretBurst - k8s, err := kubernetes.NewForConfig(restConfig) - if err != nil { - return err - } - + k8s := server.Runtime.K8s nodeRef := &corev1.ObjectReference{ Kind: "Node", Name: nodeName, diff --git a/pkg/server/server.go b/pkg/server/server.go index 8c6f40c4330c..a8c1e0d470f7 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -35,7 +35,6 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" clientset "k8s.io/client-go/kubernetes" - "k8s.io/client-go/tools/clientcmd" ) func ResolveDataDir(dataDir string) (string, error) { @@ -113,6 +112,7 @@ func runControllers(ctx context.Context, config *Config) error { controlConfig.Runtime.NodePasswdFile); err != nil { logrus.Warn(errors.Wrap(err, "error migrating node-password file")) } + controlConfig.Runtime.K8s = sc.K8s controlConfig.Runtime.K3s = sc.K3s controlConfig.Runtime.Event = sc.Event controlConfig.Runtime.Core = sc.Core @@ -208,7 +208,7 @@ func coreControllers(ctx context.Context, sc *Context, config *Config) error { } if !config.ControlConfig.DisableHelmController { - restConfig, err := clientcmd.BuildConfigFromFlags("", config.ControlConfig.Runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(config.ControlConfig.Runtime.KubeConfigSupervisor) if err != nil { return err } @@ -285,7 +285,7 @@ func stageFiles(ctx context.Context, sc *Context, controlConfig *config.Control) return err } - restConfig, err := clientcmd.BuildConfigFromFlags("", controlConfig.Runtime.KubeConfigSupervisor) + restConfig, err := util.GetRESTConfig(controlConfig.Runtime.KubeConfigSupervisor) if err != nil { return err } diff --git a/pkg/util/api.go b/pkg/util/api.go index 5ce53c49ba48..4df9ad73a945 100644 --- a/pkg/util/api.go +++ b/pkg/util/api.go @@ -23,7 +23,6 @@ import ( authorizationv1client "k8s.io/client-go/kubernetes/typed/authorization/v1" coregetter "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/rest" - "k8s.io/client-go/tools/clientcmd" "k8s.io/client-go/tools/record" ) @@ -58,7 +57,7 @@ func GetAddresses(endpoint *v1.Endpoints) []string { // readyz endpoint instead of the deprecated healthz endpoint, and supports context. func WaitForAPIServerReady(ctx context.Context, kubeconfigPath string, timeout time.Duration) error { var lastErr error - restConfig, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath) + restConfig, err := GetRESTConfig(kubeconfigPath) if err != nil { return err } @@ -112,7 +111,7 @@ type genericAccessReviewRequest func(context.Context) (*authorizationv1.SubjectA // the access would be allowed. func WaitForRBACReady(ctx context.Context, kubeconfigPath string, timeout time.Duration, ra authorizationv1.ResourceAttributes, user string, groups ...string) error { var lastErr error - restConfig, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath) + restConfig, err := GetRESTConfig(kubeconfigPath) if err != nil { return err } diff --git a/pkg/util/client.go b/pkg/util/client.go index 561a5cbc0817..a7ca9fe26b6d 100644 --- a/pkg/util/client.go +++ b/pkg/util/client.go @@ -5,12 +5,15 @@ import ( "os" "runtime" "strings" + "time" "github.com/k3s-io/k3s/pkg/datadir" "github.com/k3s-io/k3s/pkg/version" + "github.com/rancher/wrangler/v3/pkg/ratelimit" "github.com/sirupsen/logrus" "k8s.io/apimachinery/pkg/apis/meta/v1/validation" clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" "k8s.io/client-go/tools/clientcmd" ) @@ -28,7 +31,7 @@ func GetKubeConfigPath(file string) string { // GetClientSet creates a Kubernetes client from the kubeconfig at the provided path. func GetClientSet(file string) (clientset.Interface, error) { - restConfig, err := clientcmd.BuildConfigFromFlags("", file) + restConfig, err := GetRESTConfig(file) if err != nil { return nil, err } @@ -36,6 +39,18 @@ func GetClientSet(file string) (clientset.Interface, error) { return clientset.NewForConfig(restConfig) } +// GetRESTConfig returns a REST config with default timeouts and ratelimitsi cribbed from wrangler defaults. +// ref: https://github.com/rancher/wrangler/blob/v3.0.0/pkg/clients/clients.go#L184-L190 +func GetRESTConfig(file string) (*rest.Config, error) { + restConfig, err := clientcmd.BuildConfigFromFlags("", file) + if err != nil { + return nil, err + } + restConfig.Timeout = 15 * time.Minute + restConfig.RateLimiter = ratelimit.None + return restConfig, nil +} + // GetUserAgent builds a complete UserAgent string for a given controller, including the node name if possible. func GetUserAgent(controllerName string) string { nodeName := os.Getenv("NODE_NAME") diff --git a/tests/e2e/dualstack/dualstack_test.go b/tests/e2e/dualstack/dualstack_test.go index c9612f9b7142..9262af922cea 100644 --- a/tests/e2e/dualstack/dualstack_test.go +++ b/tests/e2e/dualstack/dualstack_test.go @@ -195,7 +195,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/embeddedmirror/embeddedmirror_test.go b/tests/e2e/embeddedmirror/embeddedmirror_test.go index 7188b552b988..089fb465277b 100644 --- a/tests/e2e/embeddedmirror/embeddedmirror_test.go +++ b/tests/e2e/embeddedmirror/embeddedmirror_test.go @@ -146,7 +146,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/externalip/externalip_test.go b/tests/e2e/externalip/externalip_test.go index 524bb8340276..9d2150991924 100644 --- a/tests/e2e/externalip/externalip_test.go +++ b/tests/e2e/externalip/externalip_test.go @@ -165,7 +165,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/privateregistry/privateregistry_test.go b/tests/e2e/privateregistry/privateregistry_test.go index 856f49b596c6..fe25a94e2181 100644 --- a/tests/e2e/privateregistry/privateregistry_test.go +++ b/tests/e2e/privateregistry/privateregistry_test.go @@ -149,8 +149,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/rootless/rootless_test.go b/tests/e2e/rootless/rootless_test.go index 361778c72db7..4a205934e3d5 100644 --- a/tests/e2e/rootless/rootless_test.go +++ b/tests/e2e/rootless/rootless_test.go @@ -167,7 +167,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, serverNodeNames)) + } else { Expect(e2e.GetCoverageReport(serverNodeNames)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/rotateca/rotateca_test.go b/tests/e2e/rotateca/rotateca_test.go index c43ab4d10899..3a6f2b0ca14f 100644 --- a/tests/e2e/rotateca/rotateca_test.go +++ b/tests/e2e/rotateca/rotateca_test.go @@ -138,7 +138,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/s3/s3_test.go b/tests/e2e/s3/s3_test.go index fc3be6a5fde4..b61824525934 100644 --- a/tests/e2e/s3/s3_test.go +++ b/tests/e2e/s3/s3_test.go @@ -175,8 +175,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/secretsencryption/secretsencryption_test.go b/tests/e2e/secretsencryption/secretsencryption_test.go index 187dcedba2fc..763e2f0ba381 100644 --- a/tests/e2e/secretsencryption/secretsencryption_test.go +++ b/tests/e2e/secretsencryption/secretsencryption_test.go @@ -221,7 +221,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, serverNodeNames)) + } else { Expect(e2e.GetCoverageReport(serverNodeNames)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/snapshotrestore/snapshotrestore_test.go b/tests/e2e/snapshotrestore/snapshotrestore_test.go index dc47907f78c7..f9ca105cb24b 100644 --- a/tests/e2e/snapshotrestore/snapshotrestore_test.go +++ b/tests/e2e/snapshotrestore/snapshotrestore_test.go @@ -95,7 +95,7 @@ var _ = Describe("Verify snapshots and cluster restores work", Ordered, func() { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-clusterip --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) g.Expect(err).NotTo(HaveOccurred()) - g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: "+cmd+" result: "+res) + g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: %q result: %s", cmd, res) }, "240s", "5s").Should(Succeed()) }) @@ -317,7 +317,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/splitserver/splitserver_test.go b/tests/e2e/splitserver/splitserver_test.go index c78520d67b41..642dbc1592e3 100644 --- a/tests/e2e/splitserver/splitserver_test.go +++ b/tests/e2e/splitserver/splitserver_test.go @@ -283,9 +283,11 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { - allNodes := append(cpNodeNames, etcdNodeNames...) - allNodes = append(allNodes, agentNodeNames...) + allNodes := append(cpNodeNames, etcdNodeNames...) + allNodes = append(allNodes, agentNodeNames...) + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, allNodes)) + } else { Expect(e2e.GetCoverageReport(allNodes)).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/startup/startup_test.go b/tests/e2e/startup/startup_test.go index c926164fac14..3c7dd13627ea 100644 --- a/tests/e2e/startup/startup_test.go +++ b/tests/e2e/startup/startup_test.go @@ -71,6 +71,12 @@ func KillK3sCluster(nodes []string) error { if _, err := e2e.RunCmdOnNode("k3s-killall.sh", node); err != nil { return err } + if _, err := e2e.RunCmdOnNode("journalctl --flush --sync --rotate --vacuum-size=1", node); err != nil { + return err + } + if _, err := e2e.RunCmdOnNode("rm -rf /etc/rancher/k3s/config.yaml.d", node); err != nil { + return err + } if strings.Contains(node, "server") { if _, err := e2e.RunCmdOnNode("rm -rf /var/lib/rancher/k3s/server/db", node); err != nil { return err @@ -93,6 +99,83 @@ var _ = BeforeSuite(func() { }) var _ = Describe("Various Startup Configurations", Ordered, func() { + Context("Verify dedicated supervisor port", func() { + It("Starts K3s with no issues", func() { + for _, node := range agentNodeNames { + cmd := "mkdir -p /etc/rancher/k3s/config.yaml.d; grep -F server: /etc/rancher/k3s/config.yaml | sed s/6443/9345/ > /tmp/99-server.yaml; sudo mv /tmp/99-server.yaml /etc/rancher/k3s/config.yaml.d/" + res, err := e2e.RunCmdOnNode(cmd, node) + By("checking command results: " + res) + Expect(err).NotTo(HaveOccurred()) + } + supervisorPortYAML := "supervisor-port: 9345\napiserver-port: 6443\napiserver-bind-address: 0.0.0.0\ndisable: traefik\nnode-taint: node-role.kubernetes.io/control-plane:NoExecute" + err := StartK3sCluster(append(serverNodeNames, agentNodeNames...), supervisorPortYAML, "") + Expect(err).NotTo(HaveOccurred(), e2e.GetVagrantLog(err)) + + fmt.Println("CLUSTER CONFIG") + fmt.Println("OS:", *nodeOS) + fmt.Println("Server Nodes:", serverNodeNames) + fmt.Println("Agent Nodes:", agentNodeNames) + kubeConfigFile, err = e2e.GenKubeConfigFile(serverNodeNames[0]) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Checks node and pod status", func() { + fmt.Printf("\nFetching node status\n") + Eventually(func(g Gomega) { + nodes, err := e2e.ParseNodes(kubeConfigFile, false) + g.Expect(err).NotTo(HaveOccurred()) + for _, node := range nodes { + g.Expect(node.Status).Should(Equal("Ready")) + } + }, "360s", "5s").Should(Succeed()) + _, _ = e2e.ParseNodes(kubeConfigFile, true) + + fmt.Printf("\nFetching pods status\n") + Eventually(func(g Gomega) { + pods, err := e2e.ParsePods(kubeConfigFile, false) + g.Expect(err).NotTo(HaveOccurred()) + for _, pod := range pods { + if strings.Contains(pod.Name, "helm-install") { + g.Expect(pod.Status).Should(Equal("Completed"), pod.Name) + } else { + g.Expect(pod.Status).Should(Equal("Running"), pod.Name) + } + } + }, "360s", "5s").Should(Succeed()) + _, _ = e2e.ParsePods(kubeConfigFile, true) + }) + + It("Returns pod metrics", func() { + cmd := "kubectl top pod -A" + Eventually(func() error { + _, err := e2e.RunCommand(cmd) + return err + }, "600s", "5s").Should(Succeed()) + }) + + It("Returns node metrics", func() { + cmd := "kubectl top node" + _, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Runs an interactive command a pod", func() { + cmd := "kubectl run busybox --rm -it --restart=Never --image=rancher/mirrored-library-busybox:1.36.1 -- uname -a" + _, err := e2e.RunCmdOnNode(cmd, serverNodeNames[0]) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Collects logs from a pod", func() { + cmd := "kubectl logs -n kube-system -l k8s-app=metrics-server -c metrics-server" + _, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred()) + }) + + It("Kills the cluster", func() { + err := KillK3sCluster(append(serverNodeNames, agentNodeNames...)) + Expect(err).NotTo(HaveOccurred()) + }) + }) Context("Verify CRI-Dockerd :", func() { It("Starts K3s with no issues", func() { dockerYAML := "docker: true" @@ -310,7 +393,10 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("config", e2e.GetConfig(append(serverNodeNames, agentNodeNames...))) + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go b/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go index 53128947a234..dec419e176c4 100644 --- a/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go +++ b/tests/e2e/svcpoliciesandfirewall/svcpoliciesandfirewall_test.go @@ -128,7 +128,7 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered Eventually(func(g Gomega) { externalIPs, _ := e2e.FetchExternalIPs(kubeConfigFile, lbSvcExt) g.Expect(externalIPs).To(HaveLen(1), "more than 1 exernalIP found") - g.Expect(externalIPs[0]).To(Equal(serverNodeIP),"external IP does not match servernodeIP") + g.Expect(externalIPs[0]).To(Equal(serverNodeIP), "external IP does not match servernodeIP") }, "25s", "5s").Should(Succeed()) }) @@ -154,7 +154,6 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered return e2e.RunCommand(cmd) }, "25s", "5s").ShouldNot(ContainSubstring("10.42")) - // Verify connectivity to the other nodeIP does not work because of external traffic policy=local for _, externalIP := range lbSvcExternalIPs { if externalIP == lbSvcExtExternalIPs[0] { @@ -250,7 +249,7 @@ var _ = Describe("Verify Services Traffic policies and firewall config", Ordered )) // Check the non working command fails because of internal traffic policy=local - Eventually(func() (bool) { + Eventually(func() bool { _, err := e2e.RunCommand(nonWorkingCmd) if err != nil && strings.Contains(err.Error(), "exit status") { // Treat exit status as a successful condition @@ -348,7 +347,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/tailscale/tailscale_test.go b/tests/e2e/tailscale/tailscale_test.go index 3def1ac41ab5..449840e4f990 100644 --- a/tests/e2e/tailscale/tailscale_test.go +++ b/tests/e2e/tailscale/tailscale_test.go @@ -118,7 +118,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/testutils.go b/tests/e2e/testutils.go index 950a3c8af896..2d2cb12071b0 100644 --- a/tests/e2e/testutils.go +++ b/tests/e2e/testutils.go @@ -48,7 +48,7 @@ type NodeError struct { type SvcExternalIP struct { IP string `json:"ip"` - ipMode string `json:"ipMode"` + IPMode string `json:"ipMode"` } type ObjIP struct { @@ -364,6 +364,32 @@ func GetJournalLogs(node string) (string, error) { return RunCmdOnNode(cmd, node) } +func TailJournalLogs(lines int, nodes []string) string { + logs := &strings.Builder{} + for _, node := range nodes { + cmd := fmt.Sprintf("journalctl -u k3s* --no-pager --lines=%d", lines) + if l, err := RunCmdOnNode(cmd, node); err != nil { + fmt.Fprintf(logs, "** failed to read journald log for node %s ***\n%v\n", node, err) + } else { + fmt.Fprintf(logs, "** journald log for node %s ***\n%s\n", node, l) + } + } + return logs.String() +} + +func GetConfig(nodes []string) string { + config := &strings.Builder{} + for _, node := range nodes { + cmd := "tar -Pc /etc/rancher/k3s/ | tar -vxPO" + if c, err := RunCmdOnNode(cmd, node); err != nil { + fmt.Fprintf(config, "** failed to get config for node %s ***\n%v\n", node, err) + } else { + fmt.Fprintf(config, "** config for node %s ***\n%s\n", node, c) + } + } + return config.String() +} + // GetVagrantLog returns the logs of on vagrant commands that initialize the nodes and provision K3s on each node. // It also attempts to fetch the systemctl logs of K3s on nodes where the k3s.service failed. func GetVagrantLog(cErr error) string { diff --git a/tests/e2e/token/token_test.go b/tests/e2e/token/token_test.go index bd0cc38a1fc8..3b3c011d6ae7 100644 --- a/tests/e2e/token/token_test.go +++ b/tests/e2e/token/token_test.go @@ -202,7 +202,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/upgradecluster/upgradecluster_test.go b/tests/e2e/upgradecluster/upgradecluster_test.go index 18bd1cbee7b1..fab93a6bbd89 100644 --- a/tests/e2e/upgradecluster/upgradecluster_test.go +++ b/tests/e2e/upgradecluster/upgradecluster_test.go @@ -215,14 +215,13 @@ var _ = Describe("Verify Upgrade", Ordered, func() { }, "420s", "2s").Should(Succeed()) cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec volume-test -- sh -c 'echo local-path-test > /data/test'" - _, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) + res, err := e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) fmt.Println("Data stored in pvc: local-path-test") cmd = "kubectl delete pod volume-test --kubeconfig=" + kubeConfigFile - res, err := e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) - fmt.Println(res) + res, err = e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) _, err = e2e.DeployWorkload("local-path-provisioner.yaml", kubeConfigFile, *hardened) Expect(err).NotTo(HaveOccurred(), "local-path-provisioner manifest not deployed") @@ -245,7 +244,7 @@ var _ = Describe("Verify Upgrade", Ordered, func() { Eventually(func() (string, error) { cmd := "kubectl exec volume-test --kubeconfig=" + kubeConfigFile + " -- cat /data/test" return e2e.RunCommand(cmd) - }, "180s", "2s").Should(ContainSubstring("local-path-test")) + }, "180s", "2s").Should(ContainSubstring("local-path-test"), "Failed to retrieve data from pvc") }) It("Upgrades with no issues", func() { @@ -385,7 +384,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/validatecluster/validatecluster_test.go b/tests/e2e/validatecluster/validatecluster_test.go index accae34dadf8..2c4807cce98d 100644 --- a/tests/e2e/validatecluster/validatecluster_test.go +++ b/tests/e2e/validatecluster/validatecluster_test.go @@ -95,7 +95,7 @@ var _ = Describe("Verify Create", Ordered, func() { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-clusterip --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) Expect(err).NotTo(HaveOccurred()) - g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: "+cmd+" result: "+res) + g.Expect(res).Should((ContainSubstring("test-clusterip")), "failed cmd: %q result: %s", cmd, res) }, "240s", "5s").Should(Succeed()) clusterip, _ := e2e.FetchClusterIP(kubeConfigFile, "nginx-clusterip-svc", false) @@ -130,7 +130,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-nodeport")) }, "240s", "5s").Should(Succeed()) } @@ -150,14 +150,14 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pods -o=name -l k8s-app=nginx-app-loadbalancer --field-selector=status.phase=Running --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-loadbalancer")) }, "240s", "5s").Should(Succeed()) Eventually(func(g Gomega) { cmd = "curl -L --insecure http://" + ip + ":" + port + "/name.html" res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-loadbalancer")) }, "240s", "5s").Should(Succeed()) } @@ -174,7 +174,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("test-ingress")) }, "240s", "5s").Should(Succeed()) } @@ -204,7 +204,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pods dnsutils --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("dnsutils")) }, "420s", "2s").Should(Succeed()) @@ -212,7 +212,7 @@ var _ = Describe("Verify Create", Ordered, func() { cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec -i -t dnsutils -- nslookup kubernetes.default" res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("kubernetes.default.svc.cluster.local")) }, "420s", "2s").Should(Succeed()) }) @@ -224,7 +224,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pvc local-path-pvc --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("local-path-pvc")) g.Expect(res).Should(ContainSubstring("Bound")) }, "420s", "2s").Should(Succeed()) @@ -232,18 +232,18 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pod volume-test --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("volume-test")) g.Expect(res).Should(ContainSubstring("Running")) }, "420s", "2s").Should(Succeed()) cmd := "kubectl --kubeconfig=" + kubeConfigFile + " exec volume-test -- sh -c 'echo local-path-test > /data/test'" - _, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred()) + res, err = e2e.RunCommand(cmd) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) cmd = "kubectl delete pod volume-test --kubeconfig=" + kubeConfigFile res, err = e2e.RunCommand(cmd) - Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) _, err = e2e.DeployWorkload("local-path-provisioner.yaml", kubeConfigFile, *hardened) Expect(err).NotTo(HaveOccurred(), "local-path-provisioner manifest not deployed") @@ -257,7 +257,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl get pod volume-test --kubeconfig=" + kubeConfigFile res, err := e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) g.Expect(res).Should(ContainSubstring("volume-test")) g.Expect(res).Should(ContainSubstring("Running")) @@ -266,7 +266,7 @@ var _ = Describe("Verify Create", Ordered, func() { Eventually(func(g Gomega) { cmd := "kubectl exec volume-test --kubeconfig=" + kubeConfigFile + " -- cat /data/test" res, err = e2e.RunCommand(cmd) - g.Expect(err).NotTo(HaveOccurred(), "failed cmd: "+cmd+" result: "+res) + g.Expect(err).NotTo(HaveOccurred(), "failed cmd: %q result: %s", cmd, res) fmt.Println("Data after re-creation", res) g.Expect(res).Should(ContainSubstring("local-path-test")) }, "180s", "2s").Should(Succeed()) @@ -381,7 +381,9 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if !failed { + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) + } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) } if !failed || *ci { diff --git a/tests/e2e/wasm/wasm_test.go b/tests/e2e/wasm/wasm_test.go index 1e887a086a29..7fa216088b35 100644 --- a/tests/e2e/wasm/wasm_test.go +++ b/tests/e2e/wasm/wasm_test.go @@ -135,10 +135,12 @@ var _ = AfterEach(func() { }) var _ = AfterSuite(func() { - if failed && !*ci { - fmt.Println("FAILED!") + if failed { + AddReportEntry("journald-logs", e2e.TailJournalLogs(1000, append(serverNodeNames, agentNodeNames...))) } else { Expect(e2e.GetCoverageReport(append(serverNodeNames, agentNodeNames...))).To(Succeed()) + } + if !failed || *ci { Expect(e2e.DestroyCluster()).To(Succeed()) Expect(os.Remove(kubeConfigFile)).To(Succeed()) }