Skip to content

Commit

Permalink
Add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
kmala committed Jul 11, 2024
1 parent 6c15092 commit e7daa8f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 4 deletions.
8 changes: 4 additions & 4 deletions pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,8 @@ type tokenVerifier struct {
validSTShostnames map[string]bool
}

func getDefaultHostNameForRegion(partition *endpoints.Partition, region string) (string, error) {
rep, err := partition.EndpointFor(stsServiceID, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
func getDefaultHostNameForRegion(partition *endpoints.Partition, region, service string) (string, error) {
rep, err := partition.EndpointFor(service, region, endpoints.STSRegionalEndpointOption, endpoints.ResolveUnknownServiceOption)
if err != nil {
return "", fmt.Errorf("Error resolving endpoint for %s in partition %s. err: %v", region, partition.ID(), err)
}
Expand Down Expand Up @@ -410,7 +410,7 @@ func stsHostsForPartition(partitionID, region string) map[string]bool {
logrus.Errorf("STS service not found in partition %s", partitionID)
// Add the host of the current instances region if the service doesn't already exists in the partition
// so we don't fail if the service is not present in the go sdk but matches the instances region.
stsHostName, err := getDefaultHostNameForRegion(partition, region)
stsHostName, err := getDefaultHostNameForRegion(partition, region, stsServiceID)
if err != nil {
logrus.WithError(err).Error("Error getting default hostname")
} else {
Expand All @@ -436,7 +436,7 @@ func stsHostsForPartition(partitionID, region string) map[string]bool {
// Add the host of the current instances region if not already exists so we don't fail if the region is not
// present in the go sdk but matches the instances region.
if _, ok := stsSvcEndPoints[region]; !ok {
stsHostName, err := getDefaultHostNameForRegion(partition, region)
stsHostName, err := getDefaultHostNameForRegion(partition, region, stsServiceID)
if err != nil {
logrus.WithError(err).Error("Error getting default hostname")
return validSTShostnames
Expand Down
68 changes: 68 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/google/go-cmp/cmp"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -501,3 +502,70 @@ func response(account, userID, arn string) getCallerIdentityWrapper {
wrapper.GetCallerIdentityResponse.ResponseMetadata.RequestID = "id1234"
return wrapper
}

func Test_getDefaultHostNameForRegion(t *testing.T) {
type args struct {
partition endpoints.Partition
region string
service string
}
tests := []struct {
name string
args args
want string
wantErr bool
}{
{
name: "service doesn't exist should return default host name",
args: args{
partition: endpoints.AwsIsoEPartition(),
region: "eu-isoe-west-1",
service: "test",
},
want: "test.eu-isoe-west-1.cloud.adc-e.uk",
wantErr: false,
},
{
name: "service and region doesn't exist should return default host name",
args: args{
partition: endpoints.AwsIsoEPartition(),
region: "eu-isoe-test-1",
service: "test",
},
want: "test.eu-isoe-test-1.cloud.adc-e.uk",
wantErr: false,
},
{
name: "region doesn't exist should return default host name",
args: args{
partition: endpoints.AwsIsoPartition(),
region: "us-iso-test-1",
service: "sts",
},
want: "sts.us-iso-test-1.c2s.ic.gov",
wantErr: false,
},
{
name: "invalid region should return error",
args: args{
partition: endpoints.AwsIsoPartition(),
region: "test_123",
service: "sts",
},
want: "",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getDefaultHostNameForRegion(&tt.args.partition, tt.args.region, tt.args.service)
if (err != nil) != tt.wantErr {
t.Errorf("getDefaultHostNameForRegion() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getDefaultHostNameForRegion() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit e7daa8f

Please sign in to comment.