Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WithDialTLSContextFunc option #324

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,24 +414,36 @@ func (a *BasicAuth) UnmarshalYAML(unmarshal func(interface{}) error) error {
// by net.Dialer.
type DialContextFunc func(context.Context, string, string) (net.Conn, error)

// DialTLSContextFunc defines the signature of the DialContext() function implemented
// by tls.Dialer.
type DialTLSContextFunc func(context.Context, string, string) (net.Conn, error)

type httpClientOptions struct {
dialContextFunc DialContextFunc
keepAlivesEnabled bool
http2Enabled bool
idleConnTimeout time.Duration
userAgent string
dialContextFunc DialContextFunc
dialTLSContextFunc DialTLSContextFunc
keepAlivesEnabled bool
http2Enabled bool
idleConnTimeout time.Duration
userAgent string
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
type HTTPClientOption func(options *httpClientOptions)

// WithDialContextFunc allows you to override func gets used for the actual dialing. The default is `net.Dialer.DialContext`.
// WithDialContextFunc allows you to override the func that gets used for the actual dialing. The default is `net.Dialer.DialContext`.
func WithDialContextFunc(fn DialContextFunc) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.dialContextFunc = fn
}
}

// WithDialTLSContextFunc allows you to override the func that gets used for the actual dialing. The default is `tls.Dialer.DialContext`.
func WithDialTLSContextFunc(fn DialTLSContextFunc) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.dialTLSContextFunc = fn
}
}

// WithKeepAlivesDisabled allows to disable HTTP keepalive.
func WithKeepAlivesDisabled() HTTPClientOption {
return func(opts *httpClientOptions) {
Expand Down Expand Up @@ -519,6 +531,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
DialContext: dialContext,
DialTLSContext: opts.dialTLSContextFunc,
}
if opts.http2Enabled && cfg.EnableHTTP2 {
// HTTP/2 support is golang had many problematic cornercases where
Expand Down
17 changes: 17 additions & 0 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,6 +545,23 @@ func TestCustomDialContextFunc(t *testing.T) {
}
}

func TestCustomDialTLSContextFunc(t *testing.T) {
dialFn := func(_ context.Context, _, _ string) (net.Conn, error) {
return nil, errors.New(ExpectedError)
}

cfg := HTTPClientConfig{}
client, err := NewClientFromConfig(cfg, "test", WithDialTLSContextFunc(dialFn))
if err != nil {
t.Fatalf("Can't create a client from this config: %+v", cfg)
}

_, err = client.Get("https://localhost")
if err == nil || !strings.Contains(err.Error(), ExpectedError) {
t.Errorf("Expected error %q but got %q", ExpectedError, err)
}
}

func TestCustomIdleConnTimeout(t *testing.T) {
timeout := time.Second * 5

Expand Down
4 changes: 4 additions & 0 deletions model/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ const (
// timeout used to scrape a target.
ScrapeTimeoutLabel = "__scrape_timeout__"

// ServerNameLabel is the name of the label that holds the TLS server name
// used to scrape the target.
ServerNameLabel = "__server_name__"

// ReservedLabelPrefix is a prefix which is not legal in user-supplied
// label names.
ReservedLabelPrefix = "__"
Expand Down