diff --git a/lib/vnet/ssh_agent.go b/lib/vnet/ssh_agent.go index 314db3373938e..1db229c1e82fa 100644 --- a/lib/vnet/ssh_agent.go +++ b/lib/vnet/ssh_agent.go @@ -19,6 +19,7 @@ package vnet import ( "context" "crypto/rand" + "sync" "github.com/gravitational/trace" "golang.org/x/crypto/ssh" @@ -33,10 +34,8 @@ import ( // and the root cluster proxy terminated with the SSH key in the // [ssh.ClientConfig], and then the key forwarded via this agent will be used // to terminate the final SSH connection to the target node. -// -// It is not safe for concurrent use, setSessionKey must only be called before -// the agent will actually be used. type sshAgent struct { + mu sync.Mutex signer ssh.Signer } @@ -49,6 +48,8 @@ func newSSHAgent() *sshAgent { // agent must be passed to [proxy.Client.DialHost] before the session SSH // signer has been created. func (a *sshAgent) setSessionKey(signer ssh.Signer) error { + a.mu.Lock() + defer a.mu.Unlock() if a.signer != nil { return trace.Errorf("sshAgent.setSessionKey must be called at most once (this is a bug)") } @@ -59,6 +60,8 @@ func (a *sshAgent) setSessionKey(signer ssh.Signer) error { // List implements [agent.ExtendedAgent.List], it returns a single key if it // has been set by setSessionKey. func (a *sshAgent) List() ([]*agent.Key, error) { + a.mu.Lock() + defer a.mu.Unlock() if a.signer == nil { return nil, nil } @@ -72,6 +75,8 @@ func (a *sshAgent) List() ([]*agent.Key, error) { // List implements [agent.ExtendedAgent.Signers], it returns a single key if it // has been set by setSessionKey. func (a *sshAgent) Signers() ([]ssh.Signer, error) { + a.mu.Lock() + defer a.mu.Unlock() if a.signer == nil { return nil, nil } @@ -88,6 +93,8 @@ func (a *sshAgent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error) // SSH signature with a.signer if it has been set and matches the requested // key. func (a *sshAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) { + a.mu.Lock() + defer a.mu.Unlock() if a.signer == nil { return nil, trace.Errorf("VNet SSH agent has no signer") } diff --git a/lib/vnet/vnet_test.go b/lib/vnet/vnet_test.go index 0e3ef9c42636a..07f6075aed119 100644 --- a/lib/vnet/vnet_test.go +++ b/lib/vnet/vnet_test.go @@ -1158,8 +1158,7 @@ func testWithAlgorithmSuite(t *testing.T, suite types.SignatureAlgorithmSuite) { // TestSSH tests basic VNet SSH functionality. func TestSSH(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) + ctx := t.Context() clock := clockwork.NewRealClock() homePath := t.TempDir() @@ -1343,42 +1342,44 @@ func TestSSH(t *testing.T) { t.Run(fmt.Sprintf("%s@%s:%d", tc.sshUser, tc.dialAddr, tc.dialPort), func(t *testing.T) { t.Parallel() - lookupCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) - defer cancel() - // The DNS lookup for *. should resolve to an IP in - // the expected CIDR range for the cluster. - resolvedAddrs, err := p.lookupHost(lookupCtx, tc.dialAddr) if tc.expectLookupToFail { + // In these cases the DNS lookup is expected to fail, just run the DNS lookup. + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + _, err := p.lookupHost(ctx, tc.dialAddr) require.Error(t, err) return } - require.NoError(t, err) - _, expectNet, err := net.ParseCIDR(tc.expectCIDR) - require.NoError(t, err) - - for _, resolvedAddr := range resolvedAddrs { - resolvedIP := net.ParseIP(resolvedAddr) - // The query may have resolved to a v4 or v6 address or both, - // either way the 4-byte suffix should be a valid IPv4 address - // in the expected CIDR range. - resolvedIPSuffix := resolvedIP[len(resolvedIP)-4:] - assert.True(t, expectNet.Contains(resolvedIPSuffix), - "expected CIDR range %s does not include resolved IP %s", expectNet, resolvedIPSuffix) - } - - // TCP dial the target address, it should fail if the node doesn't - // exist. - dialCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) - defer cancel() - conn, err := p.dialHost(dialCtx, tc.dialAddr, tc.dialPort) if tc.expectDialToFail { + // In these cases the DNS lookup should succeed but then the + // TCP dial should fail, do each separately to make sure we + // catch the error at the right step. + resolvedAddrs, err := p.lookupHost(ctx, tc.dialAddr) + require.NoError(t, err) + require.NotEmpty(t, resolvedAddrs) + + ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) + defer cancel() + _, err = p.dialHost(ctx, resolvedAddrs[0], tc.dialPort) require.Error(t, err) return } + + conn, err := p.dialHost(ctx, tc.dialAddr, tc.dialPort) require.NoError(t, err) defer conn.Close() + // The DNS query may have resolved to a v4 or v6 address, either + // way the 4-byte suffix should be a valid IPv4 address in the + // expected CIDR range. + resolvedIP := conn.RemoteAddr().(*net.TCPAddr).IP + resolvedIPSuffix := resolvedIP[len(resolvedIP)-4:] + _, expectNet, err := net.ParseCIDR(tc.expectCIDR) + require.NoError(t, err) + assert.True(t, expectNet.Contains(resolvedIPSuffix), + "expected CIDR range %s does not include resolved IP %s", expectNet, resolvedIPSuffix) + // Initiate an SSH connection to the target. At this point the // handshake should complete successfully as long as the right keys // are used, but the SSH connection will be immediately closed by @@ -1414,6 +1415,7 @@ func TestSSH(t *testing.T) { // Test that a fresh SSH host cert is used on each connection. t.Run("ephemeral certs", func(t *testing.T) { + t.Parallel() // Set up the SSH client config to capture the host certs it sees. var checkedHostCerts []*ssh.Certificate clientConfig := &ssh.ClientConfig{