Skip to content

Commit

Permalink
Add some more test cases; fix nits
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhenLian committed Nov 22, 2019
1 parent 6e46b5f commit 781f1c0
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 8 deletions.
2 changes: 2 additions & 0 deletions security/advancedtls/advancedtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,8 @@ func buildVerifyFunc(c *advancedTLSCreds,
return verifyFunc
}

// TODO(ZhenLian): The code below are duplicates with gRPC-Go under
// TODO(ZhenLian): credentials/internal. Consider refactoring in the future.
const alpnProtoStrH2 = "h2"

func appendH2ToNextProtos(ps []string) []string {
Expand Down
139 changes: 131 additions & 8 deletions security/advancedtls/advancedtls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
)

func TestClientServerHandshake(t *testing.T) {
// ------------------Load Client Trust Cert and Server Peer Cert-------------------
// ------------------Load Client Trust Cert and Peer Cert-------------------
clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem")
if err != nil {
t.Fatalf("Client is unable to load trust certs. Error: %v", err)
Expand All @@ -49,22 +49,24 @@ func TestClientServerHandshake(t *testing.T) {
return nil
}
return fmt.Errorf("custom verification function failed")

}
serverPeerCert, err := tls.LoadX509KeyPair("testdata/server_cert_1.pem",
"testdata/server_key_1.pem")
// ------------------Load Server Trust Cert and Client Peer Cert-------------------
clientPeerCert, err := tls.LoadX509KeyPair("testdata/client_cert_1.pem",
"testdata/client_key_1.pem")
if err != nil {
t.Fatalf("Client is unable to parse peer certificates. Error: %v", err)
}
// ------------------Load Server Trust Cert and Peer Cert-------------------
serverTrustPool, err := readTrustCert("testdata/server_trust_cert_1.pem")
if err != nil {
t.Fatalf("Server is unable to load trust certs. Error: %v", err)
}
getRootCAsForServer := func(rawConn net.Conn, rawCerts [][]byte) (*x509.CertPool, error) {
return serverTrustPool, nil
}
clientPeerCert, err := tls.LoadX509KeyPair("testdata/client_cert_1.pem",
"testdata/client_key_1.pem")
serverPeerCert, err := tls.LoadX509KeyPair("testdata/server_cert_1.pem",
"testdata/server_key_1.pem")
if err != nil {
t.Fatalf("Server is unable to parse certificates. Error: %v", err)
t.Fatalf("Server is unable to parse peer certificates. Error: %v", err)
}
getRootCAsForServerBad := func(rawConn net.Conn, rawCerts [][]byte) (*x509.CertPool, error) {
return nil, fmt.Errorf("bad root certificate reloading")
Expand Down Expand Up @@ -265,6 +267,125 @@ func TestClientServerHandshake(t *testing.T) {
getRootCAsForServerBad,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientGetClientCert
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: success
{
"Client_reload_both_certs_verifyFunc_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFunc,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
false,
},
// Client: set everything but with the wrong peer cert not trusted by server
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: server side returns failure because of
// certificate mismatch
{
"Client_wrong_peer_cert_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFunc,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set everything but with the wrong trust cert not trusting server
// Server: set serverGetRoot and serverGetCert with mutual TLS on
// Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure
{
"Client_wrong_trust_cert_Server_reload_both_certs_mutualTLS",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForServer,
verifyFunc,
false,
true,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set everything but with the wrong peer cert not trusted by client
// Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure
{
"Client_reload_both_certs_verifyFunc_Server_wrong_peer_cert",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFunc,
false,
false,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForServer,
true,
},
// Client: set clientGetRoot, clientVerifyFunc and clientCert
// Server: set everything but with the wrong trust cert not trusting client
// Expected Behavior: server side and client side return failure due to
// certificate mismatch and handshake failure
{
"Client_reload_both_certs_verifyFunc_Server_wrong_trust_cert",
nil,
func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
return &clientPeerCert, nil
},
nil,
getRootCAsForClient,
verifyFunc,
false,
true,
true,
nil,
func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
return &serverPeerCert, nil
},
nil,
getRootCAsForClient,
true,
},
} {
test := test
t.Run(test.desc, func(t *testing.T) {
Expand Down Expand Up @@ -368,6 +489,7 @@ func readTrustCert(fileName string) (*x509.CertPool, error) {
trustPool.AddCert(trustCert)
return trustPool, nil
}

func compare(a1, a2 credentials.AuthInfo) bool {
if a1.AuthType() != a2.AuthType() {
return false
Expand All @@ -387,6 +509,7 @@ func compare(a1, a2 credentials.AuthInfo) bool {
return false
}
}

func TestAdvancedTLSOverrideServerName(t *testing.T) {
expectedServerName := "server.name"
clientTrustPool, err := readTrustCert("testdata/client_trust_cert_1.pem")
Expand Down

0 comments on commit 781f1c0

Please sign in to comment.