diff --git a/internal/cmd/envoy/shutdown_manager.go b/internal/cmd/envoy/shutdown_manager.go index f1c82dae79..cd56853392 100644 --- a/internal/cmd/envoy/shutdown_manager.go +++ b/internal/cmd/envoy/shutdown_manager.go @@ -10,9 +10,12 @@ import ( "encoding/json" "errors" "fmt" + "net" "net/http" "os" "os/signal" + "regexp" + "strconv" "syscall" "time" @@ -137,7 +140,7 @@ func Shutdown(drainTimeout, minDrainDuration time.Duration, exitAtConnections in for { elapsedTime := time.Since(startTime) - conn, err := getTotalConnections() + conn, err := getTotalConnections(bootstrap.EnvoyAdminPort) if err != nil { logger.Error(err, "error getting total connections") } @@ -169,54 +172,90 @@ func Shutdown(drainTimeout, minDrainDuration time.Duration, exitAtConnections in // postEnvoyAdminAPI sends a POST request to the Envoy admin API func postEnvoyAdminAPI(path string) error { - if resp, err := http.Post(fmt.Sprintf("http://%s:%d/%s", - "localhost", bootstrap.EnvoyAdminPort, path), "application/json", nil); err != nil { + resp, err := http.Post(fmt.Sprintf("http://%s:%d/%s", + "localhost", bootstrap.EnvoyAdminPort, path), "application/json", nil) + if err != nil { return err - } else { - defer resp.Body.Close() + } + if resp == nil { + return errors.New("unexcepted nil response from Envoy admin API") + } + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected response status: %s", resp.Status) - } - return nil + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected response status: %s", resp.Status) + } + return nil +} + +func getTotalConnections(port int) (*int, error) { + return getDownstreamCXActive(port) +} + +// Define struct to decode JSON response into; expecting a single stat in the response in the format: +// {"stats":[{"name":"server.total_connections","value":123}]} +type envoyStatsResponse struct { + Stats []struct { + Name string + Value int } } -// getTotalConnections retrieves the total number of open connections from Envoy's server.total_connections stat -func getTotalConnections() (*int, error) { - // Send request to Envoy admin API to retrieve server.total_connections stat - if resp, err := http.Get(fmt.Sprintf("http://%s:%d//stats?filter=^server\\.total_connections$&format=json", - "localhost", bootstrap.EnvoyAdminPort)); err != nil { +func getStatsFromEnvoyStatsEndpoint(port int, statFilter string) (*envoyStatsResponse, error) { + resp, err := http.Get(fmt.Sprintf("http://%s//stats?filter=%s&format=json", + net.JoinHostPort("localhost", strconv.Itoa(port)), statFilter)) + if err != nil { + return nil, err + } + + defer func() { + _ = resp.Body.Close() + }() + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("unexpected response status: %s", resp.Status) + } + + r := &envoyStatsResponse{} + // Decode JSON response into struct + if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { return nil, err - } else { - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("unexpected response status: %s", resp.Status) - } else { - // Define struct to decode JSON response into; expecting a single stat in the response in the format: - // {"stats":[{"name":"server.total_connections","value":123}]} - var r *struct { - Stats []struct { - Name string - Value int - } - } - - // Decode JSON response into struct - if err := json.NewDecoder(resp.Body).Decode(&r); err != nil { - return nil, err - } - - // Defensive check for empty stats - if len(r.Stats) == 0 { - return nil, fmt.Errorf("no stats found") - } - - // Log and return total connections - c := r.Stats[0].Value - logger.Info(fmt.Sprintf("total connections: %d", c)) - return &c, nil + } + + // Defensive check for empty stats + if len(r.Stats) == 0 { + return nil, fmt.Errorf("no stats found") + } + + return r, nil +} + +// getDownstreamCXActive retrieves the total number of open connections from Envoy's listener downstream_cx_active stat +func getDownstreamCXActive(port int) (*int, error) { + // Send request to Envoy admin API to retrieve listener.\.$.downstream_cx_active stat + statFilter := "^listener\\..*\\.downstream_cx_active$" + r, err := getStatsFromEnvoyStatsEndpoint(port, statFilter) + if err != nil { + return nil, fmt.Errorf("error getting listener downstream_cx_active stat: %w", err) + } + + totalConnection := filterDownstreamCXActive(r) + logger.Info(fmt.Sprintf("total downstream connections: %d", *totalConnection)) + return totalConnection, nil +} + +// skipConnectionRE is a regex to match connection stats to be excluded from total connections count +// e.g. admin, ready and stat listener and stats from worker thread +var skipConnectionRE = regexp.MustCompile(`admin|19001|19003|worker`) + +func filterDownstreamCXActive(r *envoyStatsResponse) *int { + totalConnection := 0 + for _, stat := range r.Stats { + if excluded := skipConnectionRE.MatchString(stat.Name); !excluded { + totalConnection += stat.Value } } + + return &totalConnection } diff --git a/internal/cmd/envoy/shutdown_manager_test.go b/internal/cmd/envoy/shutdown_manager_test.go new file mode 100644 index 0000000000..16b904696f --- /dev/null +++ b/internal/cmd/envoy/shutdown_manager_test.go @@ -0,0 +1,240 @@ +// Copyright Envoy Gateway Authors +// SPDX-License-Identifier: Apache-2.0 +// The full text of the Apache license is available in the LICENSE file at +// the root of the repo. + +package envoy + +import ( + "errors" + "fmt" + "io" + "net" + "net/http" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "k8s.io/utils/ptr" +) + +// setupFakeEnvoyStats set up an HTTP server return content +func setupFakeEnvoyStats(t *testing.T, content string) *http.Server { + l, err := net.Listen("tcp", ":0") //nolint: gosec + require.NoError(t, err) + require.NoError(t, l.Close()) + mux := http.NewServeMux() + mux.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) { + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + _, _ = writer.Write([]byte(content)) + }) + + addr := l.Addr().String() + s := &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: time.Second, + } + t.Logf("start to listen at %s ", addr) + go func() { + if err := s.ListenAndServe(); err != nil { + fmt.Println("fail to listen: ", err) + } + }() + + return s +} + +func TestGetTotalConnections(t *testing.T) { + cases := []struct { + name string + input string + + expectedError error + expectedCount *int + }{ + { + name: "downstream_cx_active", + input: `{ + "stats": [ + { + "name": "listener.0.0.0.0_8000.downstream_cx_active", + "value": 1 + }, + { + "name": "listener.0.0.0.0_8000.worker_0.downstream_cx_active", + "value": 1 + }, + { + "name": "listener.0.0.0.0_8000.worker_1.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_2.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_3.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_4.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_5.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_6.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_7.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_8.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.0.0.0.0_8000.worker_9.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_0.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_1.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_2.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_3.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_4.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_5.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_6.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_7.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_8.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8080.worker_9.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_0.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_1.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_2.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_3.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_4.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_5.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_6.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_7.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_8.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.127.0.0.1_8081.worker_9.downstream_cx_active", + "value": 0 + }, + { + "name": "listener.admin.downstream_cx_active", + "value": 2 + }, + { + "name": "listener.admin.main_thread.downstream_cx_active", + "value": 2 + } + ] +}`, + expectedCount: ptr.To(1), + }, + { + name: "invalid", + input: `{"stats":[{"name":"listener.0.0.0.0_8000.downstream_cx_active","value":1]}`, + expectedError: errors.New("error getting listener downstream_cx_active stat"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + s := setupFakeEnvoyStats(t, tc.input) + _, port, err := net.SplitHostPort(s.Addr) + require.NoError(t, err) + + p, err := strconv.Atoi(port) + require.NoError(t, err) + defer func() { + _ = s.Close() + }() + reader := strings.NewReader(tc.input) + rc := io.NopCloser(reader) + defer func() { + _ = rc.Close() + }() + + gotCount, gotError := getTotalConnections(p) + if tc.expectedError != nil { + require.ErrorContains(t, gotError, tc.expectedError.Error()) + return + } + require.NoError(t, gotError) + require.Equal(t, tc.expectedCount, gotCount) + }) + } +}