-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add a helper RoundTripper for ignoring unhandleable critical certific…
…ate extensions during TLS negotiation. (#130) * Add the ignore unknown extensions helper to httputil * rename * package * go.mod, tests * disable test that won't work without a hosts entry * comment * naming * license * fumpt * Update httputil/cert_ext_tripper.go Co-authored-by: Jeff Mitchell <[email protected]> * address PR feedback * more feedback, IPv6 testing * Add port to the matrix * Add module to CI --------- Co-authored-by: Jeff Mitchell <[email protected]>
- Loading branch information
Showing
5 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ jobs: | |
"configutil", | ||
"fileutil", | ||
"gatedwriter", | ||
"httputil", | ||
"kv-builder", | ||
"listenerutil", | ||
"mlock", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Copyright (c) HashiCorp, Inc. | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
package httputil | ||
|
||
import ( | ||
"crypto/tls" | ||
"crypto/x509" | ||
"encoding/asn1" | ||
"errors" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"strings" | ||
) | ||
|
||
type ignoreExtensionsRoundTripper struct { | ||
base *http.Transport | ||
extsToIgnore []asn1.ObjectIdentifier | ||
} | ||
|
||
// NewIgnoreUnhandledExtensionsRoundTripper creates a RoundTripper that may be used in an HTTP client which will | ||
// ignore the provided extensions if presently unhandled on a certificate. If base is nil, the default RoundTripper is used. | ||
func NewIgnoreUnhandledExtensionsRoundTripper(base http.RoundTripper, extsToIgnore []asn1.ObjectIdentifier) (http.RoundTripper, error) { | ||
if len(extsToIgnore) == 0 { | ||
return nil, errors.New("no extensions ignored, should use original RoundTripper") | ||
} | ||
if base == nil { | ||
base = http.DefaultTransport | ||
} | ||
|
||
tp, ok := base.(*http.Transport) | ||
if !ok { | ||
// We don't know how to deal with this object, bail | ||
return base, nil | ||
} | ||
if tp != nil && (tp.TLSClientConfig != nil && (tp.TLSClientConfig.InsecureSkipVerify || tp.TLSClientConfig.VerifyConnection != nil)) { | ||
// Already not verifying or verifying in a custom fashion | ||
return nil, errors.New("cannot ignore provided extensions, base RoundTripper already handling or skipping verification") | ||
} | ||
return &ignoreExtensionsRoundTripper{base: tp, extsToIgnore: extsToIgnore}, nil | ||
} | ||
|
||
func (i *ignoreExtensionsRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) { | ||
domain, _, err := net.SplitHostPort(request.URL.Host) | ||
if err != nil { | ||
if strings.Contains(err.Error(), "missing port") { | ||
domain = request.URL.Host | ||
} else { | ||
return nil, fmt.Errorf("error splitting host/port: %w", err) | ||
} | ||
} | ||
|
||
var tlsConfig *tls.Config | ||
perReqTransport := i.base.Clone() | ||
if perReqTransport.TLSClientConfig != nil { | ||
tlsConfig = perReqTransport.TLSClientConfig.Clone() | ||
} else { | ||
tlsConfig = &tls.Config{} | ||
} | ||
|
||
// Domain may be an IP address, in which case we shouldn't set ServerName | ||
var ipBased bool | ||
if addr := net.ParseIP(domain); addr == nil { | ||
tlsConfig.ServerName = domain | ||
} else { | ||
ipBased = true | ||
} | ||
|
||
tlsConfig.InsecureSkipVerify = true | ||
connectionVerifier := i.customVerifyConnection(tlsConfig, ipBased) | ||
tlsConfig.VerifyConnection = connectionVerifier | ||
|
||
perReqTransport.TLSClientConfig = tlsConfig | ||
return perReqTransport.RoundTrip(request) | ||
} | ||
|
||
func (i *ignoreExtensionsRoundTripper) customVerifyConnection(tc *tls.Config, ipBased bool) func(tls.ConnectionState) error { | ||
return func(cs tls.ConnectionState) error { | ||
certs := cs.PeerCertificates | ||
|
||
serverName := cs.ServerName | ||
if cs.ServerName == "" && !ipBased { | ||
if tc.ServerName == "" { | ||
return fmt.Errorf("the ServerName in TLSClientConfig is required to be set when UnhandledExtensionsToIgnore has values") | ||
} | ||
serverName = tc.ServerName | ||
} else if cs.ServerName != tc.ServerName { | ||
return fmt.Errorf("connection state server name (%s) does not match requested (%s)", cs.ServerName, tc.ServerName) | ||
} | ||
|
||
for _, cert := range certs { | ||
if len(cert.UnhandledCriticalExtensions) == 0 { | ||
continue | ||
} | ||
var remainingUnhandled []asn1.ObjectIdentifier | ||
for _, ext := range cert.UnhandledCriticalExtensions { | ||
shouldRemove := i.isExtInIgnore(ext) | ||
if !shouldRemove { | ||
remainingUnhandled = append(remainingUnhandled, ext) | ||
} | ||
} | ||
cert.UnhandledCriticalExtensions = remainingUnhandled | ||
} | ||
|
||
// Now verify with the requested extensions removed | ||
opts := x509.VerifyOptions{ | ||
Roots: tc.RootCAs, | ||
DNSName: serverName, | ||
Intermediates: x509.NewCertPool(), | ||
} | ||
|
||
for _, cert := range certs[1:] { | ||
opts.Intermediates.AddCert(cert) | ||
} | ||
|
||
_, err := certs[0].Verify(opts) | ||
if err != nil { | ||
return &tls.CertificateVerificationError{UnverifiedCertificates: certs, Err: err} | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
func (i *ignoreExtensionsRoundTripper) isExtInIgnore(ext asn1.ObjectIdentifier) bool { | ||
for _, extToIgnore := range i.extsToIgnore { | ||
if ext.Equal(extToIgnore) { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
// Copyright (c) HashiCorp, Inc. | ||
// SPDX-License-Identifier: MPL-2.0 | ||
|
||
package httputil | ||
|
||
import ( | ||
"crypto/ecdsa" | ||
"crypto/elliptic" | ||
"crypto/rand" | ||
"crypto/tls" | ||
"crypto/x509" | ||
"crypto/x509/pkix" | ||
"encoding/asn1" | ||
"fmt" | ||
"math/big" | ||
"net" | ||
"net/http" | ||
"net/http/httptest" | ||
"net/url" | ||
"strings" | ||
"testing" | ||
"time" | ||
) | ||
|
||
var ( | ||
inhibitAnyPolicyExt = asn1.ObjectIdentifier{2, 5, 29, 54} | ||
policyConstraintExt = asn1.ObjectIdentifier{2, 5, 29, 36} | ||
) | ||
|
||
func TestClient(t *testing.T) { | ||
for _, host := range []string{"localhost", "127.0.0.1", "[::1]"} { | ||
func() { | ||
srv := newTLSServer(t, true, host) | ||
defer srv.Close() | ||
runOverrideTests(host, t, srv) | ||
|
||
srv2 := newTLSServer(t, true, host) | ||
defer srv2.Close() | ||
|
||
url, err := url.Parse(srv2.URL) | ||
if err != nil { | ||
t.Fatalf("err parsing server address: %v", err) | ||
} | ||
|
||
runOverrideTests(url.Host, t, srv2) | ||
}() | ||
} | ||
} | ||
|
||
func runOverrideTests(host string, t *testing.T, srv *httptest.Server) { | ||
tests := []struct { | ||
name string | ||
extsToIgnore []asn1.ObjectIdentifier | ||
errContains string | ||
}{ | ||
{ | ||
name: fmt.Sprintf("no-overrides-[%s]", host), | ||
errContains: "no extensions ignored", | ||
}, | ||
{ | ||
name: fmt.Sprintf("partial-override-[%s]", host), | ||
extsToIgnore: []asn1.ObjectIdentifier{inhibitAnyPolicyExt}, | ||
errContains: "x509: unhandled critical extension", | ||
}, | ||
{ | ||
name: fmt.Sprintf("full-override-[%s]", host), | ||
extsToIgnore: []asn1.ObjectIdentifier{inhibitAnyPolicyExt, policyConstraintExt}, | ||
}, | ||
} | ||
|
||
for _, tc := range tests { | ||
tc := tc | ||
t.Run(tc.name, func(t *testing.T) { | ||
client, err := getClient(t, srv, tc.extsToIgnore) | ||
if err != nil { | ||
if tc.errContains == "" { | ||
t.Fatalf("unexpected error: %v", err) | ||
} else if !strings.Contains(err.Error(), tc.errContains) { | ||
t.Fatalf("expected error to contain '%s', got '%s'", tc.errContains, err.Error()) | ||
} else { | ||
return | ||
} | ||
} | ||
resp, err := client.Get(srv.URL) | ||
if len(tc.errContains) > 0 { | ||
if err == nil { | ||
t.Fatal("expected error, got nil") | ||
} | ||
if !strings.Contains(err.Error(), tc.errContains) { | ||
t.Fatalf("expected error to contain '%s', got '%s'", tc.errContains, err.Error()) | ||
} | ||
} else { | ||
if err != nil { | ||
t.Fatalf("unexpected error: %s", err) | ||
} | ||
|
||
defer func() { _ = resp.Body.Close() }() | ||
if resp.StatusCode != http.StatusOK { | ||
t.Fatalf("got status code: %v", resp.StatusCode) | ||
} | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func getClient(t *testing.T, srv *httptest.Server, extsToIgnore []asn1.ObjectIdentifier) (*http.Client, error) { | ||
srvCertsRaw := srv.TLS.Certificates[0] | ||
rootCert, err := x509.ParseCertificate(srvCertsRaw.Certificate[0]) | ||
if err != nil { | ||
return nil, fmt.Errorf("failed parsing root ca certificate: %v", err) | ||
} | ||
|
||
certpool := x509.NewCertPool() | ||
certpool.AddCert(rootCert) | ||
rt, err := NewIgnoreUnhandledExtensionsRoundTripper(&http.Transport{ | ||
TLSClientConfig: &tls.Config{ | ||
RootCAs: certpool, | ||
}, | ||
}, extsToIgnore) | ||
if err != nil { | ||
return nil, fmt.Errorf("error instantiating round tripper: %v", err) | ||
} | ||
|
||
client := http.Client{ | ||
Transport: rt, | ||
} | ||
return &client, nil | ||
} | ||
|
||
func newTLSServer(t *testing.T, withUnsupportedExts bool, hostname string) *httptest.Server { | ||
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { | ||
defer func() { _ = req.Body.Close() }() | ||
w.WriteHeader(http.StatusOK) | ||
_, _ = w.Write([]byte("Hello World!")) | ||
})) | ||
|
||
// hack to force listening on IPv6 | ||
if hostname[0] == '[' { | ||
if l, err := net.Listen("tcp6", "[::1]:0"); err != nil { | ||
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err)) | ||
} else { | ||
ts.Listener = l | ||
} | ||
} | ||
|
||
ts.TLS = &tls.Config{Certificates: []tls.Certificate{getSelfSignedRoot(t, withUnsupportedExts)}} | ||
ts.StartTLS() | ||
ts.URL = strings.Replace(ts.URL, "127.0.0.1", hostname, 1) | ||
return ts | ||
} | ||
|
||
func getSelfSignedRoot(t *testing.T, withUnsupportedExts bool) tls.Certificate { | ||
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | ||
if err != nil { | ||
t.Fatalf("failed to generate private key: %v", err) | ||
} | ||
pub := key.Public() | ||
|
||
inhibitExt := pkix.Extension{ | ||
Id: inhibitAnyPolicyExt, | ||
Critical: true, | ||
Value: []byte{2, 1, 0}, | ||
} | ||
|
||
polConstraint := pkix.Extension{ | ||
Id: policyConstraintExt, | ||
Critical: true, | ||
Value: []byte{48, 6, 128, 1, 0, 129, 1, 0}, | ||
} | ||
|
||
caTemplate := &x509.Certificate{ | ||
Subject: pkix.Name{CommonName: "Root CA with bad extensions"}, | ||
SerialNumber: big.NewInt(1), | ||
NotBefore: time.Now().Add(-5 * time.Minute), | ||
NotAfter: time.Now().Add(10 * time.Minute), | ||
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign, | ||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, | ||
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, | ||
} | ||
if withUnsupportedExts { | ||
caTemplate.ExtraExtensions = []pkix.Extension{polConstraint, inhibitExt} | ||
caTemplate.DNSNames = []string{"localhost"} | ||
} else { | ||
caTemplate.DNSNames = []string{"example.com"} | ||
} | ||
|
||
caBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, pub, key) | ||
if err != nil { | ||
t.Fatalf("failed to marshal CA certificate: %v", err) | ||
} | ||
|
||
return tls.Certificate{ | ||
Certificate: [][]byte{caBytes}, | ||
PrivateKey: key, | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
module github.com/hashicorp/go-secure-stdlib/httputil | ||
|
||
go 1.22.5 | ||
|
||
require ( | ||
github.com/fatih/color v1.13.0 // indirect | ||
github.com/hashicorp/go-hclog v1.6.3 // indirect | ||
github.com/mattn/go-colorable v0.1.12 // indirect | ||
github.com/mattn/go-isatty v0.0.14 // indirect | ||
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= | ||
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= | ||
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= | ||
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= | ||
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= | ||
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= | ||
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= | ||
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= | ||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= | ||
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= | ||
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= | ||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= | ||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= | ||
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= | ||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= | ||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60= | ||
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= | ||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= | ||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= |