Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions lib/vnet/ssh_agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package vnet
import (
"context"
"crypto/rand"
"sync"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand All @@ -33,10 +34,8 @@ import (
// and the root cluster proxy terminated with the SSH key in the
// [ssh.ClientConfig], and then the key forwarded via this agent will be used
// to terminate the final SSH connection to the target node.
//
// It is not safe for concurrent use, setSessionKey must only be called before
// the agent will actually be used.
type sshAgent struct {
mu sync.Mutex
signer ssh.Signer
}

Expand All @@ -49,6 +48,8 @@ func newSSHAgent() *sshAgent {
// agent must be passed to [proxy.Client.DialHost] before the session SSH
// signer has been created.
func (a *sshAgent) setSessionKey(signer ssh.Signer) error {
a.mu.Lock()
defer a.mu.Unlock()
if a.signer != nil {
return trace.Errorf("sshAgent.setSessionKey must be called at most once (this is a bug)")
}
Expand All @@ -59,6 +60,8 @@ func (a *sshAgent) setSessionKey(signer ssh.Signer) error {
// List implements [agent.ExtendedAgent.List], it returns a single key if it
// has been set by setSessionKey.
func (a *sshAgent) List() ([]*agent.Key, error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.signer == nil {
return nil, nil
}
Expand All @@ -72,6 +75,8 @@ func (a *sshAgent) List() ([]*agent.Key, error) {
// List implements [agent.ExtendedAgent.Signers], it returns a single key if it
// has been set by setSessionKey.
func (a *sshAgent) Signers() ([]ssh.Signer, error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.signer == nil {
return nil, nil
}
Expand All @@ -88,6 +93,8 @@ func (a *sshAgent) Sign(key ssh.PublicKey, data []byte) (*ssh.Signature, error)
// SSH signature with a.signer if it has been set and matches the requested
// key.
func (a *sshAgent) SignWithFlags(key ssh.PublicKey, data []byte, flags agent.SignatureFlags) (*ssh.Signature, error) {
a.mu.Lock()
defer a.mu.Unlock()
if a.signer == nil {
return nil, trace.Errorf("VNet SSH agent has no signer")
}
Expand Down
54 changes: 28 additions & 26 deletions lib/vnet/vnet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1158,8 +1158,7 @@ func testWithAlgorithmSuite(t *testing.T, suite types.SignatureAlgorithmSuite) {

// TestSSH tests basic VNet SSH functionality.
func TestSSH(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
ctx := t.Context()
clock := clockwork.NewRealClock()
homePath := t.TempDir()

Expand Down Expand Up @@ -1343,42 +1342,44 @@ func TestSSH(t *testing.T) {
t.Run(fmt.Sprintf("%s@%s:%d", tc.sshUser, tc.dialAddr, tc.dialPort), func(t *testing.T) {
t.Parallel()

lookupCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
// The DNS lookup for *.<cluster-name> should resolve to an IP in
// the expected CIDR range for the cluster.
resolvedAddrs, err := p.lookupHost(lookupCtx, tc.dialAddr)
if tc.expectLookupToFail {
// In these cases the DNS lookup is expected to fail, just run the DNS lookup.
ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
_, err := p.lookupHost(ctx, tc.dialAddr)
require.Error(t, err)
return
}
require.NoError(t, err)

_, expectNet, err := net.ParseCIDR(tc.expectCIDR)
require.NoError(t, err)

for _, resolvedAddr := range resolvedAddrs {
resolvedIP := net.ParseIP(resolvedAddr)
// The query may have resolved to a v4 or v6 address or both,
// either way the 4-byte suffix should be a valid IPv4 address
// in the expected CIDR range.
resolvedIPSuffix := resolvedIP[len(resolvedIP)-4:]
assert.True(t, expectNet.Contains(resolvedIPSuffix),
"expected CIDR range %s does not include resolved IP %s", expectNet, resolvedIPSuffix)
}

// TCP dial the target address, it should fail if the node doesn't
// exist.
dialCtx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
conn, err := p.dialHost(dialCtx, tc.dialAddr, tc.dialPort)
if tc.expectDialToFail {
// In these cases the DNS lookup should succeed but then the
// TCP dial should fail, do each separately to make sure we
// catch the error at the right step.
resolvedAddrs, err := p.lookupHost(ctx, tc.dialAddr)
require.NoError(t, err)
require.NotEmpty(t, resolvedAddrs)

ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond)
defer cancel()
_, err = p.dialHost(ctx, resolvedAddrs[0], tc.dialPort)
require.Error(t, err)
return
}

conn, err := p.dialHost(ctx, tc.dialAddr, tc.dialPort)
require.NoError(t, err)
defer conn.Close()

// The DNS query may have resolved to a v4 or v6 address, either
// way the 4-byte suffix should be a valid IPv4 address in the
// expected CIDR range.
resolvedIP := conn.RemoteAddr().(*net.TCPAddr).IP
resolvedIPSuffix := resolvedIP[len(resolvedIP)-4:]
_, expectNet, err := net.ParseCIDR(tc.expectCIDR)
require.NoError(t, err)
assert.True(t, expectNet.Contains(resolvedIPSuffix),
"expected CIDR range %s does not include resolved IP %s", expectNet, resolvedIPSuffix)

// Initiate an SSH connection to the target. At this point the
// handshake should complete successfully as long as the right keys
// are used, but the SSH connection will be immediately closed by
Expand Down Expand Up @@ -1414,6 +1415,7 @@ func TestSSH(t *testing.T) {

// Test that a fresh SSH host cert is used on each connection.
t.Run("ephemeral certs", func(t *testing.T) {
t.Parallel()
// Set up the SSH client config to capture the host certs it sees.
var checkedHostCerts []*ssh.Certificate
clientConfig := &ssh.ClientConfig{
Expand Down
Loading