Skip to content

Commit

Permalink
fix nits and refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhenLian committed Dec 20, 2019
1 parent 43d11d8 commit 850dad9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 41 deletions.
62 changes: 30 additions & 32 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ type sysConn = syscall.Conn
type CustomVerificationFunc func(serverName string,
rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error

// RootCertificateOptions contains a field and a function for obtaining root trust certificates.
// It is used by both ClientOptions and ServerOptions.
type RootCertificateOptions struct {
// ---------------General Rules For Root Certificate Setting------------------------------------
// RootCACerts or GetRootCAs indicates the the certificates trusted by the client side. The rules
// for setting these two fields are:
// Either RootCACerts or GetRootCAs must be set; the other will be ignored
//----------------------------------------------------------------------------------------------
// If field RootCACerts is set, field GetRootCAs will be ignored. The client will use RootCACerts
// every time when verifying the peer certificates, without performing root certificate reloading.
RootCACerts *x509.CertPool
// If GetRootCAs is set and RootCACerts is nil, the client will invoke this function every time
// asked to check certificates sent from the server when a new connection is established.
// This is known as root CA certificate reloading.
GetRootCAs func(rawConn net.Conn, rawCerts [][]byte) (*x509.CertPool, error)
}

// ClientOptions contains all the fields and functions needed to be filled by the client.
type ClientOptions struct {
// ---------------General Rules For Certificate Setting-----------------------------------------
Expand All @@ -58,25 +75,14 @@ type ClientOptions struct {
// function every time asked to present certificates to the server when a new connection is
// established. This is known as peer certificate reloading.
GetClientCertificate func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
// ---------------General Rules For Root Certificate Setting------------------------------------
// RootCACerts or GetRootCAs indicates the the certificates trusted by the client side. The rules
// for setting these two fields are:
// Either RootCACerts or GetRootCAs must be set; the other will be ignored
//----------------------------------------------------------------------------------------------
// If field RootCACerts is set, field GetRootCAs will be ignored. The client will use RootCACerts
// every time when verifying the peer certificates, without performing root certificate reloading.
RootCACerts *x509.CertPool
// If GetRootCAs is set and RootCACerts is nil, the client will invoke this function every time
// asked to check certificates sent from the server when a new connection is established.
// This is known as root CA certificate reloading.
GetRootCAs func(rawConn net.Conn, rawCerts [][]byte) (*x509.CertPool, error)
// VerifyPeer is a custom server authorization checking after certificate signature check.
// If this is set, we will replace the hostname check with this customized authorization check.
// If this is nil, we fall back to typical hostname check.
VerifyPeer CustomVerificationFunc
// ServerNameOverride is for testing only. If set to a non-empty string,
// it will override the virtual host name of authority (e.g. :authority header field) in requests.
ServerNameOverride string
RootCertificateOptions
}

// ServerOptions contains all the fields and functions needed to be filled by the client.
Expand All @@ -93,30 +99,18 @@ type ServerOptions struct {
// function every time asked to present certificates to the client when a new connection is
// established. This is known as peer certificate reloading.
GetCertificate func(*tls.ClientHelloInfo) (*tls.Certificate, error)
// ---------------General Rules For Root Certificate Setting------------------------------------
// RootCACerts or GetRootCAs indicates the the certificates trusted by the server side. The rules
// for setting these two fields are:
// If requiring mutual authentication on server side:
// Either RootCACerts or GetRootCAs must be set; the other will be ignored
// Otherwise:
// Nothing needed(the two fields will be ignored)
//----------------------------------------------------------------------------------------------
// If field RootCACerts is set, field GetRootCAs will be ignored. The server will use RootCACerts
// every time when verifying the peer certificates, without performing root certificate reloading.
RootCACerts *x509.CertPool
// If GetRootCAs is set and RootCACerts is nil, the server will invoke this function every time
// asked to check certificates sent from the client when a new connection is established.
// This is known as root CA certificate reloading.
GetRootCAs func(rawConn net.Conn, rawCerts [][]byte) (*x509.CertPool, error)
// If the server want the client to send certificates to prove its identity.
MutualAuth bool
RootCertificateOptions
// If the server want the client to send certificates.
RequireClientCert bool
}

func (o *ClientOptions) config() (*tls.Config, error) {
if o.RootCACerts == nil && o.GetRootCAs == nil && o.VerifyPeer == nil {
return nil, fmt.Errorf(
"client needs to provide root CA certs, or a custom verification function")
}
// We have to set InsecureSkipVerify to true to skip the default checks and use the
// verification function we built from buildVerifyFunc.
config := &tls.Config{
ServerName: o.ServerNameOverride,
Certificates: o.Certificates,
Expand All @@ -131,12 +125,16 @@ func (o *ServerOptions) config() (*tls.Config, error) {
if o.Certificates == nil && o.GetCertificate == nil {
return nil, fmt.Errorf("either Certificates or GetCertificate must be specified")
}
if o.MutualAuth && o.GetRootCAs == nil && o.RootCACerts == nil {
return nil, fmt.Errorf("server needs to provide root CA certs if using mutual TLS")
if o.RequireClientCert && o.GetRootCAs == nil && o.RootCACerts == nil {
return nil, fmt.Errorf("server needs to provide root CA certs if requiring client cert")
}
clientAuth := tls.NoClientCert
if o.MutualAuth {
if o.RequireClientCert {
// We fall back to normal config settings if users don't need to reload root certificates.
// If using RequireAndVerifyClientCert, the underlying stack would use the default
// checking and ignore the verification function we built from buildVerifyFunc.
// If using RequireAnyClientCert, the code would skip all the checks and use the
// function from buildVerifyFunc.
if o.RootCACerts != nil {
clientAuth = tls.RequireAndVerifyClientCert
} else {
Expand Down
26 changes: 17 additions & 9 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,11 +418,13 @@ func TestClientServerHandshake(t *testing.T) {
}
// Start a server using ServerOptions in another goroutine.
serverOptions := &ServerOptions{
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
MutualAuth: test.serverMutualTLS,
Certificates: test.serverCert,
GetCertificate: test.serverGetCert,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.serverRoot,
GetRootCAs: test.serverGetRoot,
},
RequireClientCert: test.serverMutualTLS,
}
go func(done chan credentials.AuthInfo, lis net.Listener, serverOptions *ServerOptions) {
serverRawConn, err := lis.Accept()
Expand Down Expand Up @@ -455,9 +457,11 @@ func TestClientServerHandshake(t *testing.T) {
clientOptions := &ClientOptions{
Certificates: test.clientCert,
GetClientCertificate: test.clientGetClientCert,
RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot,
VerifyPeer: test.clientVerifyFunc,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: test.clientRoot,
GetRootCAs: test.clientGetRoot,
},
}
clientTLS, newClientErr := NewClientTLS(clientOptions)
if newClientErr != nil && test.clientExpectCreateError {
Expand Down Expand Up @@ -539,7 +543,9 @@ func TestAdvancedTLSOverrideServerName(t *testing.T) {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
clientOptions := &ClientOptions{
RootCACerts: clientTrustPool,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: clientTrustPool,
},
ServerNameOverride: expectedServerName,
}
c, err := NewClientTLS(clientOptions)
Expand All @@ -559,7 +565,9 @@ func TestTLSClone(t *testing.T) {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
}
clientOptions := &ClientOptions{
RootCACerts: clientTrustPool,
RootCertificateOptions: RootCertificateOptions{
RootCACerts: clientTrustPool,
},
ServerNameOverride: expectedServerName,
}
c, err := NewClientTLS(clientOptions)
Expand Down

0 comments on commit 850dad9

Please sign in to comment.