Skip to content

Commit

Permalink
Merge pull request #193 from hashicorp/fix-server-automtls
Browse files Browse the repository at this point in the history
automtls: fix bidirectional communication when AutoMTLS is enabled
  • Loading branch information
fairclothjm authored May 3, 2022
2 parents e4102ee + afb1659 commit 5dee41c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 4 deletions.
5 changes: 4 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,8 @@ func (c *Client) Start() (addr net.Addr, err error) {

c.config.TLSConfig = &tls.Config{
Certificates: []tls.Certificate{cert},
ClientAuth: tls.RequireAndVerifyClientCert,
MinVersion: tls.VersionTLS12,
ServerName: "localhost",
}
}
Expand Down Expand Up @@ -776,7 +778,7 @@ func (c *Client) Start() (addr net.Addr, err error) {
}

// loadServerCert is used by AutoMTLS to read an x.509 cert returned by the
// server, and load it as the RootCA for the client TLSConfig.
// server, and load it as the RootCA and ClientCA for the client TLSConfig.
func (c *Client) loadServerCert(cert string) error {
certPool := x509.NewCertPool()

Expand All @@ -793,6 +795,7 @@ func (c *Client) loadServerCert(cert string) error {
certPool.AddCert(x509Cert)

c.config.TLSConfig.RootCAs = certPool
c.config.TLSConfig.ClientCAs = certPool
return nil
}

Expand Down
6 changes: 4 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,13 +304,13 @@ func Serve(opts *ServeConfig) {

certPEM, keyPEM, err := generateCert()
if err != nil {
logger.Error("failed to generate client certificate", "error", err)
logger.Error("failed to generate server certificate", "error", err)
panic(err)
}

cert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
logger.Error("failed to parse client certificate", "error", err)
logger.Error("failed to parse server certificate", "error", err)
panic(err)
}

Expand All @@ -319,6 +319,8 @@ func Serve(opts *ServeConfig) {
ClientAuth: tls.RequireAndVerifyClientCert,
ClientCAs: clientCertPool,
MinVersion: tls.VersionTLS12,
RootCAs: clientCertPool,
ServerName: "localhost",
}

// We send back the raw leaf cert data for the client rather than the
Expand Down
70 changes: 69 additions & 1 deletion server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,75 @@ func TestServer_testMode(t *testing.T) {
t.Logf("HELLO")
}

func TestServer_testMode_AutoMTLS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

closeCh := make(chan struct{})
go Serve(&ServeConfig{
HandshakeConfig: testVersionedHandshake,
VersionedPlugins: map[int]PluginSet{
2: testGRPCPluginMap,
},
GRPCServer: DefaultGRPCServer,
Logger: hclog.NewNullLogger(),
Test: &ServeTestConfig{
Context: ctx,
ReattachConfigCh: nil,
CloseCh: closeCh,
},
})

// Connect!
process := helperProcess("test-mtls")
c := NewClient(&ClientConfig{
Cmd: process,
HandshakeConfig: testVersionedHandshake,
VersionedPlugins: map[int]PluginSet{
2: testGRPCPluginMap,
},
AllowedProtocols: []Protocol{ProtocolGRPC},
AutoMTLS: true,
})
client, err := c.Client()
if err != nil {
t.Fatalf("err: %s", err)
}

// Pinging should work
if err := client.Ping(); err != nil {
t.Fatalf("should not err: %s", err)
}

// Grab the impl
raw, err := client.Dispense("test")
if err != nil {
t.Fatalf("err should be nil, got %s", err)
}

tester, ok := raw.(testInterface)
if !ok {
t.Fatalf("bad: %#v", raw)
}

n := tester.Double(3)
if n != 6 {
t.Fatal("invalid response", n)
}

// ensure we can make use of bidirectional communication with AutoMTLS
// enabled
err = tester.Bidirectional()
if err != nil {
t.Fatal("invalid response", err)
}

c.Kill()
// Canceling should cause an exit
cancel()
<-closeCh
}

func TestRmListener_impl(t *testing.T) {
var _ net.Listener = new(rmListener)
}
Expand Down Expand Up @@ -145,7 +214,6 @@ func TestProtocolSelection_no_server(t *testing.T) {
if protocol != ProtocolNetRPC {
t.Fatalf("bad protocol %s", protocol)
}

}

func TestServer_testStdLogger(t *testing.T) {
Expand Down

0 comments on commit 5dee41c

Please sign in to comment.