Skip to content

Commit cd9c996

Browse files
Add support for custom HTTP client & preflight (#112)
* Add support for custom HTTP client & preflight * add some comments * make the mtls switch work better
1 parent e08ee3c commit cd9c996

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

pkg/config/config.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ func SetUser(user string) error {
160160

161161
func MtlsEnabled() bool {
162162
authMethod := viper.GetString("authentication_method")
163-
return authMethod == "mtls"
163+
return authMethod == "mtls" && viper.GetBool("mtls_settings.enabled")
164164
}
165165

166166
// BaseWebURL allows the ConsoleMe URL to be overridden for cases where the API

pkg/creds/consoleme.go

+45-4
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,8 @@ import (
4848
var clientVersion = fmt.Sprintf("%s", metadata.Version)
4949

5050
var userAgent = "weep/" + clientVersion + " Go-http-client/1.1"
51-
52-
type Account struct {
53-
}
51+
var clientFactoryOverride ClientFactory
52+
var preflightFunctions = make([]RequestPreflight, 0)
5453

5554
// HTTPClient is the interface we expect HTTP clients to implement.
5655
type HTTPClient interface {
@@ -67,13 +66,39 @@ type Client struct {
6766
Region string
6867
}
6968

69+
type ClientFactory func() (*http.Client, error)
70+
71+
// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client
72+
// creation with a ClientFactory. This function will be called during the creation
73+
// of all ConsoleMe clients.
74+
func RegisterClientFactory(factory ClientFactory) {
75+
clientFactoryOverride = factory
76+
}
77+
78+
type RequestPreflight func(req *http.Request) error
79+
80+
// RegisterRequestPreflight adds a RequestPreflight function which will be called in the
81+
// order of registration during the creation of a ConsoleMe request.
82+
func RegisterRequestPreflight(preflight RequestPreflight) {
83+
preflightFunctions = append(preflightFunctions, preflight)
84+
}
85+
7086
// GetClient creates an authenticated ConsoleMe client
7187
func GetClient(region string) (*Client, error) {
7288
var client *Client
7389
consoleMeUrl := viper.GetString("consoleme_url")
7490
authenticationMethod := viper.GetString("authentication_method")
7591

76-
if authenticationMethod == "mtls" {
92+
if clientFactoryOverride != nil {
93+
customClient, err := clientFactoryOverride()
94+
if err != nil {
95+
return client, err
96+
}
97+
client, err = NewClient(consoleMeUrl, "", customClient)
98+
if err != nil {
99+
return client, err
100+
}
101+
} else if authenticationMethod == "mtls" {
77102
mtlsClient, err := mtls.NewHTTPClient()
78103
if err != nil {
79104
return client, err
@@ -122,6 +147,18 @@ func NewClient(hostname string, region string, httpc *http.Client) (*Client, err
122147
return c, nil
123148
}
124149

150+
func runPreflightFunctions(req *http.Request) error {
151+
var err error
152+
if preflightFunctions != nil {
153+
for _, preflight := range preflightFunctions {
154+
if err = preflight(req); err != nil {
155+
return err
156+
}
157+
}
158+
}
159+
return nil
160+
}
161+
125162
func (c *Client) buildRequest(method string, resource string, body io.Reader, apiPrefix string) (*http.Request, error) {
126163
urlStr := c.Host + apiPrefix + resource
127164
req, err := http.NewRequest(method, urlStr, body)
@@ -130,6 +167,10 @@ func (c *Client) buildRequest(method string, resource string, body io.Reader, ap
130167
}
131168
req.Header.Set("User-Agent", userAgent)
132169
req.Header.Add("Content-Type", "application/json")
170+
err = runPreflightFunctions(req)
171+
if err != nil {
172+
return nil, err
173+
}
133174

134175
return req, nil
135176
}

0 commit comments

Comments
 (0)