Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 82 additions & 43 deletions internal/cmd/envoy/shutdown_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@ import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"os/signal"
"regexp"
"strconv"
"syscall"
"time"

Expand Down Expand Up @@ -137,7 +140,7 @@ func Shutdown(drainTimeout, minDrainDuration time.Duration, exitAtConnections in
for {
elapsedTime := time.Since(startTime)

conn, err := getTotalConnections()
conn, err := getTotalConnections(bootstrap.EnvoyAdminPort)
if err != nil {
logger.Error(err, "error getting total connections")
}
Expand Down Expand Up @@ -169,54 +172,90 @@ func Shutdown(drainTimeout, minDrainDuration time.Duration, exitAtConnections in

// postEnvoyAdminAPI sends a POST request to the Envoy admin API
func postEnvoyAdminAPI(path string) error {
if resp, err := http.Post(fmt.Sprintf("http://%s:%d/%s",
"localhost", bootstrap.EnvoyAdminPort, path), "application/json", nil); err != nil {
resp, err := http.Post(fmt.Sprintf("http://%s:%d/%s",
"localhost", bootstrap.EnvoyAdminPort, path), "application/json", nil)
if err != nil {
return err
} else {
defer resp.Body.Close()
}
if resp == nil {
return errors.New("unexcepted nil response from Envoy admin API")
}
defer func() {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can there be a case where resp is nil

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AFAIK, resp shouldn't be nil when there's no error.

_ = resp.Body.Close()
}()

if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected response status: %s", resp.Status)
}
return nil
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected response status: %s", resp.Status)
}
return nil
}

func getTotalConnections(port int) (*int, error) {
return getDownstreamCXActive(port)
}

// Define struct to decode JSON response into; expecting a single stat in the response in the format:
// {"stats":[{"name":"server.total_connections","value":123}]}
type envoyStatsResponse struct {
Stats []struct {
Name string
Value int
}
}

// getTotalConnections retrieves the total number of open connections from Envoy's server.total_connections stat
func getTotalConnections() (*int, error) {
// Send request to Envoy admin API to retrieve server.total_connections stat
if resp, err := http.Get(fmt.Sprintf("http://%s:%d//stats?filter=^server\\.total_connections$&format=json",
"localhost", bootstrap.EnvoyAdminPort)); err != nil {
func getStatsFromEnvoyStatsEndpoint(port int, statFilter string) (*envoyStatsResponse, error) {
resp, err := http.Get(fmt.Sprintf("http://%s//stats?filter=%s&format=json",
net.JoinHostPort("localhost", strconv.Itoa(port)), statFilter))
if err != nil {
return nil, err
}

defer func() {
_ = resp.Body.Close()
}()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected response status: %s", resp.Status)
}

r := &envoyStatsResponse{}
// Decode JSON response into struct
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return nil, err
} else {
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("unexpected response status: %s", resp.Status)
} else {
// Define struct to decode JSON response into; expecting a single stat in the response in the format:
// {"stats":[{"name":"server.total_connections","value":123}]}
var r *struct {
Stats []struct {
Name string
Value int
}
}

// Decode JSON response into struct
if err := json.NewDecoder(resp.Body).Decode(&r); err != nil {
return nil, err
}

// Defensive check for empty stats
if len(r.Stats) == 0 {
return nil, fmt.Errorf("no stats found")
}

// Log and return total connections
c := r.Stats[0].Value
logger.Info(fmt.Sprintf("total connections: %d", c))
return &c, nil
}

// Defensive check for empty stats
if len(r.Stats) == 0 {
return nil, fmt.Errorf("no stats found")
}

return r, nil
}

// getDownstreamCXActive retrieves the total number of open connections from Envoy's listener downstream_cx_active stat
func getDownstreamCXActive(port int) (*int, error) {
// Send request to Envoy admin API to retrieve listener.\.$.downstream_cx_active stat
statFilter := "^listener\\..*\\.downstream_cx_active$"
r, err := getStatsFromEnvoyStatsEndpoint(port, statFilter)
if err != nil {
return nil, fmt.Errorf("error getting listener downstream_cx_active stat: %w", err)
}

totalConnection := filterDownstreamCXActive(r)
logger.Info(fmt.Sprintf("total downstream connections: %d", *totalConnection))
return totalConnection, nil
}

// skipConnectionRE is a regex to match connection stats to be excluded from total connections count
// e.g. admin, ready and stat listener and stats from worker thread
var skipConnectionRE = regexp.MustCompile(`admin|19001|19003|worker`)

func filterDownstreamCXActive(r *envoyStatsResponse) *int {
totalConnection := 0
for _, stat := range r.Stats {
if excluded := skipConnectionRE.MatchString(stat.Name); !excluded {
totalConnection += stat.Value
}
}

return &totalConnection
}
240 changes: 240 additions & 0 deletions internal/cmd/envoy/shutdown_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
// Copyright Envoy Gateway Authors
// SPDX-License-Identifier: Apache-2.0
// The full text of the Apache license is available in the LICENSE file at
// the root of the repo.

package envoy

import (
"errors"
"fmt"
"io"
"net"
"net/http"
"strconv"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
"k8s.io/utils/ptr"
)

// setupFakeEnvoyStats set up an HTTP server return content
func setupFakeEnvoyStats(t *testing.T, content string) *http.Server {
l, err := net.Listen("tcp", ":0") //nolint: gosec
require.NoError(t, err)
require.NoError(t, l.Close())
mux := http.NewServeMux()
mux.HandleFunc("/", func(writer http.ResponseWriter, _ *http.Request) {
writer.Header().Set("Content-Type", "application/json")
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(content))
})

addr := l.Addr().String()
s := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: time.Second,
}
t.Logf("start to listen at %s ", addr)
go func() {
if err := s.ListenAndServe(); err != nil {
fmt.Println("fail to listen: ", err)
}
}()

return s
}

func TestGetTotalConnections(t *testing.T) {
cases := []struct {
name string
input string

expectedError error
expectedCount *int
}{
{
name: "downstream_cx_active",
input: `{
"stats": [
{
"name": "listener.0.0.0.0_8000.downstream_cx_active",
"value": 1
},
{
"name": "listener.0.0.0.0_8000.worker_0.downstream_cx_active",
"value": 1
},
{
"name": "listener.0.0.0.0_8000.worker_1.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_2.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_3.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_4.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_5.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_6.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_7.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_8.downstream_cx_active",
"value": 0
},
{
"name": "listener.0.0.0.0_8000.worker_9.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_0.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_1.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_2.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_3.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_4.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_5.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_6.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_7.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_8.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8080.worker_9.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_0.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_1.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_2.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_3.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_4.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_5.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_6.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_7.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_8.downstream_cx_active",
"value": 0
},
{
"name": "listener.127.0.0.1_8081.worker_9.downstream_cx_active",
"value": 0
},
{
"name": "listener.admin.downstream_cx_active",
"value": 2
},
{
"name": "listener.admin.main_thread.downstream_cx_active",
"value": 2
}
]
}`,
expectedCount: ptr.To(1),
},
{
name: "invalid",
input: `{"stats":[{"name":"listener.0.0.0.0_8000.downstream_cx_active","value":1]}`,
expectedError: errors.New("error getting listener downstream_cx_active stat"),
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
s := setupFakeEnvoyStats(t, tc.input)
_, port, err := net.SplitHostPort(s.Addr)
require.NoError(t, err)

p, err := strconv.Atoi(port)
require.NoError(t, err)
defer func() {
_ = s.Close()
}()
reader := strings.NewReader(tc.input)
rc := io.NopCloser(reader)
defer func() {
_ = rc.Close()
}()

gotCount, gotError := getTotalConnections(p)
if tc.expectedError != nil {
require.ErrorContains(t, gotError, tc.expectedError.Error())
return
}
require.NoError(t, gotError)
require.Equal(t, tc.expectedCount, gotCount)
})
}
}