Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve status API detection and validation #447

Merged
merged 2 commits into from
Aug 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 70 additions & 22 deletions sdk/config_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@ package sdk
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net"
"net/http"
Expand Down Expand Up @@ -669,41 +671,46 @@ func AddAuxfileToNginxConfig(

func parseAddressesFromServerDirective(parent *crossplane.Directive) []string {
addresses := []string{}
hosts := []string{}
port := "80"

for _, dir := range parent.Block {
address := "127.0.0.1"
hostname := "127.0.0.1"

switch dir.Directive {
case "listen":
host, listenPort, err := net.SplitHostPort(dir.Args[0])
if err == nil {
if host == "*" || host == "" {
address = "127.0.0.1"
hostname = "127.0.0.1"
} else if host == "::" || host == "::1" {
address = "[::1]"
hostname = "[::1]"
} else {
address = host
hostname = host
}
port = listenPort
} else {
if isPort(dir.Args[0]) {
port = dir.Args[0]
} else {
address = dir.Args[0]
hostname = dir.Args[0]
}
}
addresses = append(addresses, fmt.Sprintf("%s:%s", address, port))
hosts = append(hosts, hostname)
case "server_name":
if dir.Args[0] == "_" {
// default server
continue
}
address = dir.Args[0]
addresses = append(addresses, fmt.Sprintf("%s:%s", address, port))
hostname = dir.Args[0]
hosts = append(hosts, hostname)
}
}

for _, host := range hosts {
addresses = append(addresses, fmt.Sprintf("%s:%s", host, port))
}

return addresses
}

Expand All @@ -729,15 +736,15 @@ func statusAPICallback(parent *crossplane.Directive, current *crossplane.Directi
plusUrls := getUrlsForLocationDirective(parent, current, plusAPIDirective)

for _, url := range plusUrls {
if pingStatusAPIEndpoint(url) {
if pingNginxPlusApiEndpoint(url) {
log.Debugf("api at %q found", url)
return url
}
log.Debugf("api at %q is not reachable", url)
}

for _, url := range ossUrls {
if pingStatusAPIEndpoint(url) {
if pingStubStatusApiEndpoint(url) {
log.Debugf("stub_status at %q found", url)
return url
}
Expand All @@ -747,16 +754,6 @@ func statusAPICallback(parent *crossplane.Directive, current *crossplane.Directi
return ""
}

// pingStatusAPIEndpoint ensures the statusAPI is reachable from the agent
func pingStatusAPIEndpoint(statusAPI string) bool {
client := http.Client{Timeout: 50 * time.Millisecond}

if _, err := client.Head(statusAPI); err != nil {
return false
}
return true
}

// Deprecated: use either GetStubStatusApiUrl or GetNginxPlusApiUrl
func GetStatusApiInfoWithIgnoreDirectives(confFile string, ignoreDirectives []string) (statusApi string, err error) {
payload, err := crossplane.Parse(confFile,
Expand Down Expand Up @@ -834,7 +831,7 @@ func stubStatusApiCallback(parent *crossplane.Directive, current *crossplane.Dir
urls := getUrlsForLocationDirective(parent, current, stubStatusAPIDirective)

for _, url := range urls {
if pingStatusAPIEndpoint(url) {
if pingStubStatusApiEndpoint(url) {
log.Debugf("stub_status at %q found", url)
return url
}
Expand All @@ -848,7 +845,7 @@ func nginxPlusApiCallback(parent *crossplane.Directive, current *crossplane.Dire
urls := getUrlsForLocationDirective(parent, current, plusAPIDirective)

for _, url := range urls {
if pingStatusAPIEndpoint(url) {
if pingNginxPlusApiEndpoint(url) {
log.Debugf("plus API at %q found", url)
return url
}
Expand All @@ -858,6 +855,57 @@ func nginxPlusApiCallback(parent *crossplane.Directive, current *crossplane.Dire
return ""
}

func pingStubStatusApiEndpoint(statusAPI string) bool {
client := http.Client{Timeout: 50 * time.Millisecond}
resp, err := client.Get(statusAPI)
if err != nil {
return false
}

if resp.StatusCode != 200 {
return false
}

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return false
}

// Expecting API to return data like this:
//
// Active connections: 2
// server accepts handled requests
// 18 18 3266
// Reading: 0 Writing: 1 Waiting: 1
body := string(bodyBytes)
return strings.Contains(body, "Active connections") && strings.Contains(body, "server accepts handled requests")
}

func pingNginxPlusApiEndpoint(statusAPI string) bool {
client := http.Client{Timeout: 50 * time.Millisecond}
resp, err := client.Get(statusAPI)
if err != nil {
return false
}

if resp.StatusCode != 200 {
return false
}

bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return false
}

// Expecting API to return the api versions in an array like this:
//
// [1,2,3,4,5,6,7,8]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[1,2,3,4,5,6,7,8,9] since R30
I'd say give it a more generic comment like expecting API to return a positive integer

var responseBody []int
err = json.Unmarshal(bodyBytes, &responseBody)

return err == nil
}

func getUrlsForLocationDirective(parent *crossplane.Directive, current *crossplane.Directive, locationDirectiveName string) []string {
var urls []string
// process from the location block
Expand Down
129 changes: 126 additions & 3 deletions sdk/config_helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,12 +779,16 @@ func TestGetStatusApiInfo(t *testing.T) {

server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.String() == "/privateapi" {
data := []byte("api OK")
data := []byte("[1,2,3,4,5,6,7,8]")
_, err := rw.Write(data)
assert.Nil(t, err)
} else if req.URL.String() == "/stub_status" {
rw.WriteHeader(http.StatusInternalServerError)
data := []byte("stub_status OK")
data := []byte(`
Active connections: 2
server accepts handled requests
18 18 3266
Reading: 0 Writing: 1 Waiting: 1
`)
_, err := rw.Write(data)
assert.Nil(t, err)
}
Expand Down Expand Up @@ -1164,6 +1168,22 @@ server {
allow 127.0.0.1;
deny all;
}
}
`,
},
{
plus: []string{
"http://127.0.0.1:49151/api",
"http://127.0.0.1:49151/api",
},
conf: `
server {
server_name 127.0.0.1;
listen 127.0.0.1:49151;
access_log off;
location /api {
api;
}
}
`,
},
Expand Down Expand Up @@ -1784,3 +1804,106 @@ func TestGetAppProtectPolicyAndSecurityLogFiles(t *testing.T) {
})
}
}

func TestPingNginxPlusApiEndpoint(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.String() == "/good_api" {
data := []byte("[1,2,3,4,5,6,7,8]")
_, err := rw.Write(data)
assert.Nil(t, err)
} else if req.URL.String() == "/invalid_body_api" {
data := []byte("Invalid")
_, err := rw.Write(data)
assert.Nil(t, err)
} else {
rw.WriteHeader(http.StatusInternalServerError)
data := []byte("")
_, err := rw.Write(data)
assert.Nil(t, err)
}
}))
defer server.Close()

testCases := []struct {
name string
endpoint string
expected bool
}{
{
name: "valid API",
endpoint: "/good_api",
expected: true,
},
{
name: "invalid response status code",
endpoint: "/bad_api",
expected: false,
},
{
name: "invalid response body",
endpoint: "/invalid_body_api",
expected: false,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := pingNginxPlusApiEndpoint(fmt.Sprintf("%s%s", server.URL, testCase.endpoint))
assert.Equal(t, testCase.expected, result)
})
}
}

func TestPingStubStatusApiEndpoint(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
if req.URL.String() == "/good_api" {
data := []byte(`
Active connections: 2
server accepts handled requests
18 18 3266
Reading: 0 Writing: 1 Waiting: 1
`)
_, err := rw.Write(data)
assert.Nil(t, err)
} else if req.URL.String() == "/invalid_body_api" {
data := []byte("Invalid")
_, err := rw.Write(data)
assert.Nil(t, err)
} else {
rw.WriteHeader(http.StatusInternalServerError)
data := []byte("")
_, err := rw.Write(data)
assert.Nil(t, err)
}
}))
defer server.Close()

testCases := []struct {
name string
endpoint string
expected bool
}{
{
name: "valid API",
endpoint: "/good_api",
expected: true,
},
{
name: "invalid response status code",
endpoint: "/bad_api",
expected: false,
},
{
name: "invalid response body",
endpoint: "/invalid_body_api",
expected: false,
},
}

for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
result := pingStubStatusApiEndpoint(fmt.Sprintf("%s%s", server.URL, testCase.endpoint))
assert.Equal(t, testCase.expected, result)
})
}
}
Loading