Skip to content

Commit e2aec9a

Browse files
Refactor HTTP client creation to be more generic (#113)
1 parent cd9c996 commit e2aec9a

File tree

10 files changed

+96
-80
lines changed

10 files changed

+96
-80
lines changed

cmd/credential_process.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ func generateCredentialProcessConfig(destination string) error {
8989
if destination == "" {
9090
return fmt.Errorf("no destination provided")
9191
}
92-
client, err := creds.GetClient(region)
92+
client, err := creds.GetClient()
9393
if err != nil {
9494
return err
9595
}

cmd/helpers.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ func preInteractiveCheck(region string, client *creds.Client) (*creds.Client, er
197197
// If a client was not provided, create one using the provided region
198198
if client == nil {
199199
var err error
200-
client, err = creds.GetClient(region)
200+
client, err = creds.GetClient()
201201
if err != nil {
202202
return nil, err
203203
}

cmd/list.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ var listCmd = &cobra.Command{
4747
}
4848

4949
func roleList() (string, error) {
50-
client, err := creds.GetClient(region)
50+
client, err := creds.GetClient()
5151
if err != nil {
5252
return "", err
5353
}

cmd/open.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func runOpen(cmd *cobra.Command, args []string) error {
5353
return errors.New("Resource type sns and sqs require region in the arn")
5454
}
5555
var resourceURL string
56-
client, err := creds.GetClient(region)
56+
client, err := creds.GetClient()
5757
if err != nil {
5858
logging.LogError(err, "Error getting client")
5959
return err

pkg/creds/consoleme.go

+10-72
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,15 @@ import (
3030
"strings"
3131
"time"
3232

33+
"github.com/netflix/weep/pkg/httpAuth"
34+
"github.com/netflix/weep/pkg/httpAuth/custom"
35+
3336
"github.com/netflix/weep/pkg/util"
3437

3538
"github.com/netflix/weep/pkg/aws"
3639
"github.com/netflix/weep/pkg/config"
3740
werrors "github.com/netflix/weep/pkg/errors"
3841
"github.com/netflix/weep/pkg/httpAuth/challenge"
39-
"github.com/netflix/weep/pkg/httpAuth/mtls"
4042
"github.com/netflix/weep/pkg/logging"
4143
"github.com/netflix/weep/pkg/metadata"
4244

@@ -48,8 +50,6 @@ import (
4850
var clientVersion = fmt.Sprintf("%s", metadata.Version)
4951

5052
var userAgent = "weep/" + clientVersion + " Go-http-client/1.1"
51-
var clientFactoryOverride ClientFactory
52-
var preflightFunctions = make([]RequestPreflight, 0)
5353

5454
// HTTPClient is the interface we expect HTTP clients to implement.
5555
type HTTPClient interface {
@@ -66,65 +66,15 @@ type Client struct {
6666
Region string
6767
}
6868

69-
type ClientFactory func() (*http.Client, error)
70-
71-
// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client
72-
// creation with a ClientFactory. This function will be called during the creation
73-
// of all ConsoleMe clients.
74-
func RegisterClientFactory(factory ClientFactory) {
75-
clientFactoryOverride = factory
76-
}
77-
78-
type RequestPreflight func(req *http.Request) error
79-
80-
// RegisterRequestPreflight adds a RequestPreflight function which will be called in the
81-
// order of registration during the creation of a ConsoleMe request.
82-
func RegisterRequestPreflight(preflight RequestPreflight) {
83-
preflightFunctions = append(preflightFunctions, preflight)
84-
}
85-
8669
// GetClient creates an authenticated ConsoleMe client
87-
func GetClient(region string) (*Client, error) {
70+
func GetClient() (*Client, error) {
8871
var client *Client
8972
consoleMeUrl := viper.GetString("consoleme_url")
90-
authenticationMethod := viper.GetString("authentication_method")
91-
92-
if clientFactoryOverride != nil {
93-
customClient, err := clientFactoryOverride()
94-
if err != nil {
95-
return client, err
96-
}
97-
client, err = NewClient(consoleMeUrl, "", customClient)
98-
if err != nil {
99-
return client, err
100-
}
101-
} else if authenticationMethod == "mtls" {
102-
mtlsClient, err := mtls.NewHTTPClient()
103-
if err != nil {
104-
return client, err
105-
}
106-
client, err = NewClient(consoleMeUrl, "", mtlsClient)
107-
if err != nil {
108-
return client, err
109-
}
110-
} else if authenticationMethod == "challenge" {
111-
err := challenge.RefreshChallenge()
112-
if err != nil {
113-
return client, err
114-
}
115-
httpClient, err := challenge.NewHTTPClient(consoleMeUrl)
116-
if err != nil {
117-
return client, err
118-
}
119-
client, err = NewClient(consoleMeUrl, "", httpClient)
120-
if err != nil {
121-
return client, err
122-
}
123-
} else {
124-
return nil, fmt.Errorf("Authentication method unsupported or not provided.")
73+
httpClient, err := httpAuth.GetAuthenticatedClient()
74+
if err != nil {
75+
return client, err
12576
}
126-
127-
return client, nil
77+
return NewClient(consoleMeUrl, "", httpClient)
12878
}
12979

13080
// NewClient takes a ConsoleMe hostname and *http.Client, and returns a
@@ -147,18 +97,6 @@ func NewClient(hostname string, region string, httpc *http.Client) (*Client, err
14797
return c, nil
14898
}
14999

150-
func runPreflightFunctions(req *http.Request) error {
151-
var err error
152-
if preflightFunctions != nil {
153-
for _, preflight := range preflightFunctions {
154-
if err = preflight(req); err != nil {
155-
return err
156-
}
157-
}
158-
}
159-
return nil
160-
}
161-
162100
func (c *Client) buildRequest(method string, resource string, body io.Reader, apiPrefix string) (*http.Request, error) {
163101
urlStr := c.Host + apiPrefix + resource
164102
req, err := http.NewRequest(method, urlStr, body)
@@ -167,7 +105,7 @@ func (c *Client) buildRequest(method string, resource string, body io.Reader, ap
167105
}
168106
req.Header.Set("User-Agent", userAgent)
169107
req.Header.Add("Content-Type", "application/json")
170-
err = runPreflightFunctions(req)
108+
err = custom.RunPreflightFunctions(req)
171109
if err != nil {
172110
return nil, err
173111
}
@@ -579,7 +517,7 @@ func GetCredentialsC(client HTTPClient, role string, ipRestrict bool, assumeRole
579517
// GetCredentials requests credentials from ConsoleMe then follows the provided chain of roles to
580518
// assume. Roles are assumed in the order in which they appear in the assumeRole slice.
581519
func GetCredentials(role string, ipRestrict bool, assumeRole []string, region string) (*aws.Credentials, error) {
582-
client, err := GetClient(region)
520+
client, err := GetClient()
583521
if err != nil {
584522
return nil, err
585523
}

pkg/httpAuth/custom/custom.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package custom
2+
3+
import "net/http"
4+
5+
var isOverridden bool
6+
var clientFactoryOverride ClientFactory
7+
var preflightFunctions = make([]RequestPreflight, 0)
8+
9+
type ClientFactory func() (*http.Client, error)
10+
11+
func UseCustom() bool {
12+
return isOverridden
13+
}
14+
15+
func NewHTTPClient() (*http.Client, error) {
16+
return clientFactoryOverride()
17+
}
18+
19+
// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client
20+
// creation with a ClientFactory. This function will be called during the creation
21+
// of all ConsoleMe clients.
22+
func RegisterClientFactory(factory ClientFactory) {
23+
clientFactoryOverride = factory
24+
isOverridden = true
25+
}
26+
27+
type RequestPreflight func(req *http.Request) error
28+
29+
// RegisterRequestPreflight adds a RequestPreflight function which will be called in the
30+
// order of registration during the creation of a ConsoleMe request.
31+
func RegisterRequestPreflight(preflight RequestPreflight) {
32+
preflightFunctions = append(preflightFunctions, preflight)
33+
}
34+
35+
func RunPreflightFunctions(req *http.Request) error {
36+
var err error
37+
if preflightFunctions != nil {
38+
for _, preflight := range preflightFunctions {
39+
if err = preflight(req); err != nil {
40+
return err
41+
}
42+
}
43+
}
44+
return nil
45+
}

pkg/httpAuth/httpAuth.go

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
package httpAuth
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
7+
"github.com/netflix/weep/pkg/httpAuth/challenge"
8+
"github.com/netflix/weep/pkg/httpAuth/custom"
9+
"github.com/netflix/weep/pkg/httpAuth/mtls"
10+
"github.com/spf13/viper"
11+
)
12+
13+
func GetAuthenticatedClient() (*http.Client, error) {
14+
authenticationMethod := viper.GetString("authentication_method")
15+
consoleMeUrl := viper.GetString("consoleme_url")
16+
if custom.UseCustom() {
17+
return custom.NewHTTPClient()
18+
} else if authenticationMethod == "mtls" {
19+
return mtls.NewHTTPClient()
20+
} else if authenticationMethod == "challenge" {
21+
err := challenge.RefreshChallenge()
22+
if err != nil {
23+
return nil, err
24+
}
25+
return challenge.NewHTTPClient(consoleMeUrl)
26+
}
27+
return nil, fmt.Errorf("Authentication method unsupported or not provided.")
28+
}

pkg/server/ecsCredentialsHandler.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func parseAssumeRoleQuery(r *http.Request) ([]string, error) {
5656

5757
func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request) {
5858
return func(w http.ResponseWriter, r *http.Request) {
59-
var client, err = creds.GetClient(region)
59+
var client, err = creds.GetClient()
6060
if err != nil {
6161
logging.Log.Error(err)
6262
util.WriteError(w, err.Error(), http.StatusBadRequest)

pkg/server/server.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er
3131

3232
if isServingIMDS {
3333
logging.Log.Infof("Configuring weep IMDS service for role %s", role)
34-
client, err := creds.GetClient(region)
34+
client, err := creds.GetClient()
3535
if err != nil {
3636
return err
3737
}

pkg/swag/swag.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ import (
55
"fmt"
66
"net/http"
77

8-
"github.com/netflix/weep/pkg/httpAuth/mtls"
8+
"github.com/netflix/weep/pkg/creds"
9+
910
"github.com/spf13/viper"
1011
)
1112

@@ -15,7 +16,11 @@ type SwagResponse struct {
1516

1617
func getClient() (*http.Client, error) {
1718
if viper.GetBool("swag.use_mtls") {
18-
return mtls.NewHTTPClient()
19+
consoleMeClient, err := creds.GetClient()
20+
if err != nil {
21+
return nil, err
22+
}
23+
return &consoleMeClient.Client, nil
1924
}
2025
return http.DefaultClient, nil
2126
}

0 commit comments

Comments
 (0)