Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ import (
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/agentconn"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

Expand Down Expand Up @@ -4733,19 +4732,6 @@ func loopbackPool(proxyAddr string) *x509.CertPool {
return certPool
}

// connectToSSHAgent connects to the system SSH agent and returns an agent.Agent.
func connectToSSHAgent() agent.ExtendedAgent {
socketPath := os.Getenv(teleport.SSHAuthSock)
conn, err := agentconn.Dial(socketPath)
if err != nil {
log.Warnf("[KEY AGENT] Unable to connect to SSH agent on socket %q: %v", socketPath, err)
return nil
}

log.Infof("[KEY AGENT] Connected to the system agent: %q", socketPath)
return agent.NewClient(conn)
}

// Username returns the current user's username
func Username() (string, error) {
u, err := apiutils.CurrentUser()
Expand Down
9 changes: 8 additions & 1 deletion lib/client/keyagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/api/utils/prompt"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/tlsca"
)

Expand Down Expand Up @@ -134,7 +135,13 @@ func NewLocalAgent(conf LocalAgentConfig) (a *LocalKeyAgent, err error) {
}

if shouldAddKeysToAgent(conf.KeysOption) {
a.systemAgent = connectToSSHAgent()
a.log.Debug("Connecting to the system agent")
systemAgent, err := sshagent.NewSystemAgentClient()
if err != nil {
a.log.Warnf("Unable to connect to system agent: %v", err)
} else {
a.systemAgent = systemAgent
}
} else {
log.Debug("Skipping connection to the local ssh-agent.")

Expand Down
11 changes: 6 additions & 5 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -253,11 +254,11 @@ func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback
// if agent forwarding was requested (and we have a agent to forward),
// forward the agent to endpoint.
tc := ns.nodeClient.TC
targetAgent := selectKeyAgent(tc)
targetAgent := selectKeyAgent(ctx, tc)

if targetAgent != nil {
log.Debugf("Forwarding Selected Key Agent")
err = agent.ForwardToAgent(ns.nodeClient.Client.Client, targetAgent)
err = sshagent.ServeChannelRequests(ctx, ns.nodeClient.Client.Client, targetAgent)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -272,14 +273,14 @@ func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback

// selectKeyAgent picks the appropriate key agent for forwarding to the
// server, if any.
func selectKeyAgent(tc *TeleportClient) agent.ExtendedAgent {
func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGetter {
switch tc.ForwardAgent {
case ForwardAgentYes:
log.Debugf("Selecting system key agent.")
return connectToSSHAgent()
return sshagent.NewSystemAgentClient
case ForwardAgentLocal:
log.Debugf("Selecting local Teleport key agent.")
return tc.localAgent.ExtendedAgent
return sshagent.NewStaticClientGetter(tc.localAgent.ExtendedAgent)
default:
log.Debugf("No Key Agent selected.")
return nil
Expand Down
136 changes: 136 additions & 0 deletions lib/sshagent/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package sshagent

import (
"context"
"errors"
"io"
"log/slog"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// Client extends the [agent.ExtendedAgent] interface with an [io.Closer].
type Client interface {
agent.ExtendedAgent
io.Closer
}

// ClientGetter is a function used to get a new agent client.
type ClientGetter = func() (Client, error)

type client struct {
agent.ExtendedAgent
conn io.Closer
}

// NewClient creates a new SSH Agent client with an open connection using
// the provided connection function. The resulting connection can be any
// [io.ReadWriteCloser], such as a [net.Conn] or [ssh.Channel].
func NewClient(connect func() (io.ReadWriteCloser, error)) (Client, error) {
conn, err := connect()
if err != nil {
return nil, trace.Wrap(err)
}

return &client{
ExtendedAgent: agent.NewClient(conn),
conn: conn,
}, nil
}

// NewSystemAgentClient creates a new SSH Agent client with an open connection
// to the system agent, advertised by SSH_AUTH_SOCK or other system parameters.
func NewSystemAgentClient() (Client, error) {
return NewClient(DialSystemAgent)
}

// NewStaticClient creates a new SSH Agent client for the given static agent.
func NewStaticClient(agentClient agent.ExtendedAgent) Client {
return &client{
ExtendedAgent: agentClient,
}
}

// NewStaticClientGetter returns a [ClientGetter] for a static agent client.
func NewStaticClientGetter(agentClient agent.ExtendedAgent) ClientGetter {
return func() (Client, error) {
return &client{
ExtendedAgent: agentClient,
}, nil
}
}

// Close the agent client and prevent further requests.
func (c *client) Close() error {
if c.conn == nil {
return nil
}
err := c.conn.Close()
return trace.Wrap(err)
}

const channelType = "auth-agent@openssh.com"

// ServeChannelRequests routes agent channel requests to a new agent
// connection retrieved from the provided getter.
//
// This method differs from [agent.ForwardToAgent] in that each agent
// forwarding channel is handled with a new connection to the forward
// agent, rather than sharing a single long-lived connection.
//
// Specifically, this is necessary for Windows' named pipe ssh agent
// implementation, as the named pipe connection can be disrupted after
// signature requests. This issue may be resolved directly by the
// [agent] library once https://github.com/golang/go/issues/61383
// is addressed.
func ServeChannelRequests(ctx context.Context, client *ssh.Client, getForwardAgent ClientGetter) error {
channels := client.HandleChannelOpen(channelType)
if channels == nil {
return errors.New("agent forwarding channel already open")
}

go func() {
for ch := range channels {
channel, reqs, err := ch.Accept()
if err != nil {
continue
}

go ssh.DiscardRequests(reqs)

forwardAgent, err := getForwardAgent()
if err != nil {
_ = channel.Close()
slog.ErrorContext(ctx, "failed to connect to forwarded agent", "err", err)
continue
}

go func() {
defer channel.Close()
defer forwardAgent.Close()
if err := agent.ServeAgent(forwardAgent, channel); err != nil && !errors.Is(err, io.EOF) {
slog.ErrorContext(ctx, "unexpected error serving forwarded agent", "err", err)
}
}()
}
}()
return nil
}
133 changes: 133 additions & 0 deletions lib/sshagent/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package sshagent_test

import (
"context"
"io"
"net"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh/agent"

"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/utils"
)

func TestSSHAgentClient(t *testing.T) {
keyring, ok := agent.NewKeyring().(agent.ExtendedAgent)
require.True(t, ok)

agentDir := t.TempDir()
agentPath := filepath.Join(agentDir, "agent.sock")
startAgentServer := func() (stop func()) {
l, err := net.Listen("unix", agentPath)
require.NoError(t, err)

// create a context to close existing connections on server shutdown.
serveCtx, serveCancel := context.WithCancel(context.Background())

// Track open connections.
var connWg sync.WaitGroup

connWg.Add(1)
go func() {
defer connWg.Done()
for {
conn, err := l.Accept()
if err != nil {
assert.True(t, utils.IsUseOfClosedNetworkError(err))
return
}

closeConn := func() {
conn.Close()
connWg.Done()
}

// Close the connection early if the server is stopped.
connNotClosed := context.AfterFunc(serveCtx, closeConn)

connWg.Add(1)
go func() {
defer func() {
if connNotClosed() {
closeConn()
}
}()
agent.ServeAgent(keyring, conn)
}()
}
}()

// Close the listener, stop serving, and wait for all open client
// connections to close.
stopServer := func() {
l.Close()
serveCancel()
connWg.Wait()
}

return stopServer
}

stopServer := startAgentServer()
t.Cleanup(stopServer)

clientConnect := func() (io.ReadWriteCloser, error) {
return net.Dial("unix", agentPath)
}
clientGetter := func() (sshagent.Client, error) {
return sshagent.NewClient(clientConnect)
}

// Get a new agent client and make successful requests.
agentClient, err := clientGetter()
require.NoError(t, err)
_, err = agentClient.List()
require.NoError(t, err)

// Close the server and all client connections, client should fail.
stopServer()
_, err = agentClient.List()
// TODO(Joerger): Ideally we would check for the error (io.EOF),
// but the agent library isn't properly wrapping its errors.
require.Error(t, err)

// Getting a new client should fail.
_, err = clientGetter()
require.Error(t, err)

// Re-open the server. Get a new agent client connection.
stopServer = startAgentServer()
t.Cleanup(stopServer)

agentClient, err = clientGetter()
require.NoError(t, err)
_, err = agentClient.List()
require.NoError(t, err)

// Close the client, it should return an error when receiving requests.
err = agentClient.Close()
require.NoError(t, err)
_, err = agentClient.List()
require.Error(t, err)
}
Loading
Loading