Skip to content

Commit 63566f0

Browse files
RanVakninskmcgrail
andauthored
Update BuildAuthToken to validate endpoint contains a port (#1837)
* validated that the right side of the colon has to be an string representation of an integer * fixed linter error * Add changelog description Co-authored-by: Sean McGrail <[email protected]>
1 parent b011f04 commit 63566f0

File tree

3 files changed

+72
-10
lines changed

3 files changed

+72
-10
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"id": "aaba2642-a8f6-4293-a3ad-7dd5bc0e9ef7",
3+
"type": "feature",
4+
"description": "Updated `BuildAuthToken` to validate the provided endpoint contains a port.",
5+
"modules": [
6+
"feature/rds/auth"
7+
]
8+
}

feature/rds/auth/connect.go

+29
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"net/http"
7+
"strconv"
78
"strings"
89
"time"
910

@@ -44,6 +45,11 @@ type BuildAuthTokenOptions struct{}
4445
// See http://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html
4546
// for more information on using IAM database authentication with RDS.
4647
func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds aws.CredentialsProvider, optFns ...func(options *BuildAuthTokenOptions)) (string, error) {
48+
_, port := validateURL(endpoint)
49+
if port == "" {
50+
return "", fmt.Errorf("the provided endpoint is missing a port, or the provided port is invalid")
51+
}
52+
4753
o := BuildAuthTokenOptions{}
4854

4955
for _, fn := range optFns {
@@ -94,3 +100,26 @@ func BuildAuthToken(ctx context.Context, endpoint, region, dbUser string, creds
94100

95101
return url, nil
96102
}
103+
104+
func validateURL(hostPort string) (host, port string) {
105+
colon := strings.LastIndexByte(hostPort, ':')
106+
if colon != -1 {
107+
host, port = hostPort[:colon], hostPort[colon+1:]
108+
}
109+
if !validatePort(port) {
110+
port = ""
111+
return
112+
}
113+
if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
114+
host = host[1 : len(host)-1]
115+
}
116+
117+
return
118+
}
119+
120+
func validatePort(port string) bool {
121+
if _, err := strconv.Atoi(port); err == nil {
122+
return true
123+
}
124+
return false
125+
}

feature/rds/auth/connect_test.go

+35-10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package auth_test
33
import (
44
"context"
55
"regexp"
6+
"strings"
67
"testing"
78

89
"github.com/aws/aws-sdk-go-v2/aws"
@@ -15,27 +16,51 @@ func TestBuildAuthToken(t *testing.T) {
1516
region string
1617
user string
1718
expectedRegex string
19+
expectedError string
1820
}{
1921
{
20-
"https://prod-instance.us-east-1.rds.amazonaws.com:3306",
21-
"us-west-2",
22-
"mysqlUser",
23-
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
22+
endpoint: "https://prod-instance.us-east-1.rds.amazonaws.com:3306",
23+
region: "us-west-2",
24+
user: "mysqlUser",
25+
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
2426
},
2527
{
26-
"prod-instance.us-east-1.rds.amazonaws.com:3306",
27-
"us-west-2",
28-
"mysqlUser",
29-
`^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
28+
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:3306",
29+
region: "us-west-2",
30+
user: "mysqlUser",
31+
expectedRegex: `^prod-instance\.us-east-1\.rds\.amazonaws\.com:3306\?Action=connect.*?DBUser=mysqlUser.*`,
32+
},
33+
{
34+
endpoint: "prod-instance.us-east-1.rds.amazonaws.com",
35+
region: "us-west-2",
36+
user: "mysqlUser",
37+
expectedError: "port",
38+
},
39+
{
40+
endpoint: "prod-instance.us-east-1.rds.amazonaws.com:kakasdkasd",
41+
region: "us-west-2",
42+
user: "mysqlUser",
43+
expectedError: "port",
3044
},
3145
}
3246

3347
for _, c := range cases {
3448
creds := &staticCredentials{AccessKey: "AKID", SecretKey: "SECRET", Session: "SESSION"}
3549
url, err := auth.BuildAuthToken(context.Background(), c.endpoint, c.region, c.user, creds)
36-
if err != nil {
37-
t.Errorf("expect no error, got %v", err)
50+
if len(c.expectedError) > 0 {
51+
if err != nil {
52+
if !strings.Contains(err.Error(), c.expectedError) {
53+
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
54+
} else {
55+
continue
56+
}
57+
} else {
58+
t.Fatalf("expect err: %v, actual err: %v", c.expectedError, err)
59+
}
60+
} else if err != nil {
61+
t.Fatalf("expect no err, got: %v", err)
3862
}
63+
3964
if re, a := regexp.MustCompile(c.expectedRegex), url; !re.MatchString(a) {
4065
t.Errorf("expect %s to match %s", re, a)
4166
}

0 commit comments

Comments
 (0)