Skip to content

Commit

Permalink
Add a helper RoundTripper for ignoring unhandleable critical certific…
Browse files Browse the repository at this point in the history
…ate extensions during TLS negotiation. (#130)

* Add the ignore unknown extensions helper to httputil

* rename

* package

* go.mod, tests

* disable test that won't work without a hosts entry

* comment

* naming

* license

* fumpt

* Update httputil/cert_ext_tripper.go

Co-authored-by: Jeff Mitchell <[email protected]>

* address PR feedback

* more feedback, IPv6 testing

* Add port to the matrix

* Add module to CI

---------

Co-authored-by: Jeff Mitchell <[email protected]>
  • Loading branch information
sgmiller and jefferai authored Aug 2, 2024
1 parent fab9dfb commit 62edfce
Show file tree
Hide file tree
Showing 5 changed files with 365 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jobs:
"configutil",
"fileutil",
"gatedwriter",
"httputil",
"kv-builder",
"listenerutil",
"mlock",
Expand Down
134 changes: 134 additions & 0 deletions httputil/cert_ext_tripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package httputil

import (
"crypto/tls"
"crypto/x509"
"encoding/asn1"
"errors"
"fmt"
"net"
"net/http"
"strings"
)

type ignoreExtensionsRoundTripper struct {
base *http.Transport
extsToIgnore []asn1.ObjectIdentifier
}

// NewIgnoreUnhandledExtensionsRoundTripper creates a RoundTripper that may be used in an HTTP client which will
// ignore the provided extensions if presently unhandled on a certificate. If base is nil, the default RoundTripper is used.
func NewIgnoreUnhandledExtensionsRoundTripper(base http.RoundTripper, extsToIgnore []asn1.ObjectIdentifier) (http.RoundTripper, error) {
if len(extsToIgnore) == 0 {
return nil, errors.New("no extensions ignored, should use original RoundTripper")
}
if base == nil {
base = http.DefaultTransport
}

tp, ok := base.(*http.Transport)
if !ok {
// We don't know how to deal with this object, bail
return base, nil
}
if tp != nil && (tp.TLSClientConfig != nil && (tp.TLSClientConfig.InsecureSkipVerify || tp.TLSClientConfig.VerifyConnection != nil)) {
// Already not verifying or verifying in a custom fashion
return nil, errors.New("cannot ignore provided extensions, base RoundTripper already handling or skipping verification")
}
return &ignoreExtensionsRoundTripper{base: tp, extsToIgnore: extsToIgnore}, nil
}

func (i *ignoreExtensionsRoundTripper) RoundTrip(request *http.Request) (*http.Response, error) {
domain, _, err := net.SplitHostPort(request.URL.Host)
if err != nil {
if strings.Contains(err.Error(), "missing port") {
domain = request.URL.Host
} else {
return nil, fmt.Errorf("error splitting host/port: %w", err)
}
}

var tlsConfig *tls.Config
perReqTransport := i.base.Clone()
if perReqTransport.TLSClientConfig != nil {
tlsConfig = perReqTransport.TLSClientConfig.Clone()
} else {
tlsConfig = &tls.Config{}
}

// Domain may be an IP address, in which case we shouldn't set ServerName
var ipBased bool
if addr := net.ParseIP(domain); addr == nil {
tlsConfig.ServerName = domain
} else {
ipBased = true
}

tlsConfig.InsecureSkipVerify = true
connectionVerifier := i.customVerifyConnection(tlsConfig, ipBased)
tlsConfig.VerifyConnection = connectionVerifier

perReqTransport.TLSClientConfig = tlsConfig
return perReqTransport.RoundTrip(request)
}

func (i *ignoreExtensionsRoundTripper) customVerifyConnection(tc *tls.Config, ipBased bool) func(tls.ConnectionState) error {
return func(cs tls.ConnectionState) error {
certs := cs.PeerCertificates

serverName := cs.ServerName
if cs.ServerName == "" && !ipBased {
if tc.ServerName == "" {
return fmt.Errorf("the ServerName in TLSClientConfig is required to be set when UnhandledExtensionsToIgnore has values")
}
serverName = tc.ServerName
} else if cs.ServerName != tc.ServerName {
return fmt.Errorf("connection state server name (%s) does not match requested (%s)", cs.ServerName, tc.ServerName)
}

for _, cert := range certs {
if len(cert.UnhandledCriticalExtensions) == 0 {
continue
}
var remainingUnhandled []asn1.ObjectIdentifier
for _, ext := range cert.UnhandledCriticalExtensions {
shouldRemove := i.isExtInIgnore(ext)
if !shouldRemove {
remainingUnhandled = append(remainingUnhandled, ext)
}
}
cert.UnhandledCriticalExtensions = remainingUnhandled
}

// Now verify with the requested extensions removed
opts := x509.VerifyOptions{
Roots: tc.RootCAs,
DNSName: serverName,
Intermediates: x509.NewCertPool(),
}

for _, cert := range certs[1:] {
opts.Intermediates.AddCert(cert)
}

_, err := certs[0].Verify(opts)
if err != nil {
return &tls.CertificateVerificationError{UnverifiedCertificates: certs, Err: err}
}

return nil
}
}

func (i *ignoreExtensionsRoundTripper) isExtInIgnore(ext asn1.ObjectIdentifier) bool {
for _, extToIgnore := range i.extsToIgnore {
if ext.Equal(extToIgnore) {
return true
}
}

return false
}
196 changes: 196 additions & 0 deletions httputil/cert_ext_tripper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package httputil

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
)

var (
inhibitAnyPolicyExt = asn1.ObjectIdentifier{2, 5, 29, 54}
policyConstraintExt = asn1.ObjectIdentifier{2, 5, 29, 36}
)

func TestClient(t *testing.T) {
for _, host := range []string{"localhost", "127.0.0.1", "[::1]"} {
func() {
srv := newTLSServer(t, true, host)
defer srv.Close()
runOverrideTests(host, t, srv)

srv2 := newTLSServer(t, true, host)
defer srv2.Close()

url, err := url.Parse(srv2.URL)
if err != nil {
t.Fatalf("err parsing server address: %v", err)
}

runOverrideTests(url.Host, t, srv2)
}()
}
}

func runOverrideTests(host string, t *testing.T, srv *httptest.Server) {
tests := []struct {
name string
extsToIgnore []asn1.ObjectIdentifier
errContains string
}{
{
name: fmt.Sprintf("no-overrides-[%s]", host),
errContains: "no extensions ignored",
},
{
name: fmt.Sprintf("partial-override-[%s]", host),
extsToIgnore: []asn1.ObjectIdentifier{inhibitAnyPolicyExt},
errContains: "x509: unhandled critical extension",
},
{
name: fmt.Sprintf("full-override-[%s]", host),
extsToIgnore: []asn1.ObjectIdentifier{inhibitAnyPolicyExt, policyConstraintExt},
},
}

for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
client, err := getClient(t, srv, tc.extsToIgnore)
if err != nil {
if tc.errContains == "" {
t.Fatalf("unexpected error: %v", err)
} else if !strings.Contains(err.Error(), tc.errContains) {
t.Fatalf("expected error to contain '%s', got '%s'", tc.errContains, err.Error())
} else {
return
}
}
resp, err := client.Get(srv.URL)
if len(tc.errContains) > 0 {
if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), tc.errContains) {
t.Fatalf("expected error to contain '%s', got '%s'", tc.errContains, err.Error())
}
} else {
if err != nil {
t.Fatalf("unexpected error: %s", err)
}

defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
t.Fatalf("got status code: %v", resp.StatusCode)
}
}
})
}
}

func getClient(t *testing.T, srv *httptest.Server, extsToIgnore []asn1.ObjectIdentifier) (*http.Client, error) {
srvCertsRaw := srv.TLS.Certificates[0]
rootCert, err := x509.ParseCertificate(srvCertsRaw.Certificate[0])
if err != nil {
return nil, fmt.Errorf("failed parsing root ca certificate: %v", err)
}

certpool := x509.NewCertPool()
certpool.AddCert(rootCert)
rt, err := NewIgnoreUnhandledExtensionsRoundTripper(&http.Transport{
TLSClientConfig: &tls.Config{
RootCAs: certpool,
},
}, extsToIgnore)
if err != nil {
return nil, fmt.Errorf("error instantiating round tripper: %v", err)
}

client := http.Client{
Transport: rt,
}
return &client, nil
}

func newTLSServer(t *testing.T, withUnsupportedExts bool, hostname string) *httptest.Server {
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
defer func() { _ = req.Body.Close() }()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("Hello World!"))
}))

// hack to force listening on IPv6
if hostname[0] == '[' {
if l, err := net.Listen("tcp6", "[::1]:0"); err != nil {
panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
} else {
ts.Listener = l
}
}

ts.TLS = &tls.Config{Certificates: []tls.Certificate{getSelfSignedRoot(t, withUnsupportedExts)}}
ts.StartTLS()
ts.URL = strings.Replace(ts.URL, "127.0.0.1", hostname, 1)
return ts
}

func getSelfSignedRoot(t *testing.T, withUnsupportedExts bool) tls.Certificate {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("failed to generate private key: %v", err)
}
pub := key.Public()

inhibitExt := pkix.Extension{
Id: inhibitAnyPolicyExt,
Critical: true,
Value: []byte{2, 1, 0},
}

polConstraint := pkix.Extension{
Id: policyConstraintExt,
Critical: true,
Value: []byte{48, 6, 128, 1, 0, 129, 1, 0},
}

caTemplate := &x509.Certificate{
Subject: pkix.Name{CommonName: "Root CA with bad extensions"},
SerialNumber: big.NewInt(1),
NotBefore: time.Now().Add(-5 * time.Minute),
NotAfter: time.Now().Add(10 * time.Minute),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
}
if withUnsupportedExts {
caTemplate.ExtraExtensions = []pkix.Extension{polConstraint, inhibitExt}
caTemplate.DNSNames = []string{"localhost"}
} else {
caTemplate.DNSNames = []string{"example.com"}
}

caBytes, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, pub, key)
if err != nil {
t.Fatalf("failed to marshal CA certificate: %v", err)
}

return tls.Certificate{
Certificate: [][]byte{caBytes},
PrivateKey: key,
}
}
11 changes: 11 additions & 0 deletions httputil/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/hashicorp/go-secure-stdlib/httputil

go 1.22.5

require (
github.com/fatih/color v1.13.0 // indirect
github.com/hashicorp/go-hclog v1.6.3 // indirect
github.com/mattn/go-colorable v0.1.12 // indirect
github.com/mattn/go-isatty v0.0.14 // indirect
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 // indirect
)
23 changes: 23 additions & 0 deletions httputil/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w=
github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k=
github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6 h1:nonptSpoQ4vQjyraW20DXPAglgQfVnM9ZC6MmNLMR60=
golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

0 comments on commit 62edfce

Please sign in to comment.