Skip to content
20 changes: 0 additions & 20 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ import (
"github.com/gravitational/teleport/lib/sshutils/sftp"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/agentconn"
logutils "github.com/gravitational/teleport/lib/utils/log"
)

Expand Down Expand Up @@ -4915,25 +4914,6 @@ func loopbackPool(proxyAddr string) *x509.CertPool {
return certPool
}

// connectToSSHAgent connects to the system SSH agent and returns an agent.Agent.
func connectToSSHAgent() agent.ExtendedAgent {
ctx := context.Background()
logger := log.With(teleport.ComponentKey, teleport.ComponentKeyAgent)

socketPath := os.Getenv(teleport.SSHAuthSock)
conn, err := agentconn.Dial(socketPath)
if err != nil {
logger.WarnContext(ctx, "Unable to connect to SSH agent on socket",
"socket_path", socketPath,
"error", err,
)
return nil
}

logger.InfoContext(ctx, "Connected to the system agent", "socket_path", socketPath)
return agent.NewClient(conn)
}

// Username returns the current user's username
func Username() (string, error) {
u, err := user.Current()
Expand Down
9 changes: 8 additions & 1 deletion lib/client/keyagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
"github.com/gravitational/teleport/api/utils/prompt"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/tlsca"
logutils "github.com/gravitational/teleport/lib/utils/log"
)
Expand Down Expand Up @@ -133,7 +134,13 @@ func NewLocalAgent(conf LocalAgentConfig) (a *LocalKeyAgent, err error) {
}

if shouldAddKeysToAgent(conf.KeysOption) {
a.systemAgent = connectToSSHAgent()
a.log.DebugContext(context.Background(), "Connecting to the system agent")
systemAgent, err := sshagent.NewSystemAgentClient()
if err != nil {
a.log.WarnContext(context.Background(), "Unable to connect to system agent", "error", err)
} else {
a.systemAgent = systemAgent
}
} else {
log.DebugContext(context.Background(), "Skipping connection to the local ssh-agent.")

Expand Down
17 changes: 9 additions & 8 deletions lib/client/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ import (
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/x11"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -248,11 +249,11 @@ func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback
// if agent forwarding was requested (and we have a agent to forward),
// forward the agent to endpoint.
tc := ns.nodeClient.TC
targetAgent := selectKeyAgent(tc)
targetAgent := selectKeyAgent(ctx, tc)

if targetAgent != nil {
log.DebugContext(ctx, "Forwarding Selected Key Agent")
err = agent.ForwardToAgent(ns.nodeClient.Client.Client, targetAgent)
err = sshagent.ServeChannelRequests(ctx, ns.nodeClient.Client.Client, targetAgent)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -267,16 +268,16 @@ func (ns *NodeSession) createServerSession(ctx context.Context, chanReqCallback

// selectKeyAgent picks the appropriate key agent for forwarding to the
// server, if any.
func selectKeyAgent(tc *TeleportClient) agent.ExtendedAgent {
func selectKeyAgent(ctx context.Context, tc *TeleportClient) sshagent.ClientGetter {
switch tc.ForwardAgent {
case ForwardAgentYes:
log.DebugContext(context.Background(), "Selecting system key agent")
return connectToSSHAgent()
log.DebugContext(ctx, "Selecting system key agent")
return sshagent.NewSystemAgentClient
case ForwardAgentLocal:
log.DebugContext(context.Background(), "Selecting local Teleport key agent")
return tc.localAgent.ExtendedAgent
log.DebugContext(ctx, "Selecting local Teleport key agent")
return sshagent.NewStaticClientGetter(tc.localAgent.ExtendedAgent)
default:
log.DebugContext(context.Background(), "No Key Agent selected")
log.DebugContext(ctx, "No Key Agent selected")
return nil
}
}
Expand Down
136 changes: 136 additions & 0 deletions lib/sshagent/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package sshagent

import (
"context"
"errors"
"io"
"log/slog"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

// Client extends the [agent.ExtendedAgent] interface with an [io.Closer].
type Client interface {
agent.ExtendedAgent
io.Closer
}

// ClientGetter is a function used to get a new agent client.
type ClientGetter = func() (Client, error)

type client struct {
agent.ExtendedAgent
conn io.Closer
}

// NewClient creates a new SSH Agent client with an open connection using
// the provided connection function. The resulting connection can be any
// [io.ReadWriteCloser], such as a [net.Conn] or [ssh.Channel].
func NewClient(connect func() (io.ReadWriteCloser, error)) (Client, error) {
conn, err := connect()
if err != nil {
return nil, trace.Wrap(err)
}

return &client{
ExtendedAgent: agent.NewClient(conn),
conn: conn,
}, nil
}

// NewSystemAgentClient creates a new SSH Agent client with an open connection
// to the system agent, advertised by SSH_AUTH_SOCK or other system parameters.
func NewSystemAgentClient() (Client, error) {
return NewClient(DialSystemAgent)
}

// NewStaticClient creates a new SSH Agent client for the given static agent.
func NewStaticClient(agentClient agent.ExtendedAgent) Client {
return &client{
ExtendedAgent: agentClient,
}
}

// NewStaticClientGetter returns a [ClientGetter] for a static agent client.
func NewStaticClientGetter(agentClient agent.ExtendedAgent) ClientGetter {
return func() (Client, error) {
return &client{
ExtendedAgent: agentClient,
}, nil
}
}

// Close the agent client and prevent further requests.
func (c *client) Close() error {
if c.conn == nil {
return nil
}
err := c.conn.Close()
return trace.Wrap(err)
}

const channelType = "auth-agent@openssh.com"

// ServeChannelRequests routes agent channel requests to a new agent
// connection retrieved from the provided getter.
//
// This method differs from [agent.ForwardToAgent] in that each agent
// forwarding channel is handled with a new connection to the forward
// agent, rather than sharing a single long-lived connection.
//
// Specifically, this is necessary for Windows' named pipe ssh agent
// implementation, as the named pipe connection can be disrupted after
// signature requests. This issue may be resolved directly by the
// [agent] library once https://github.com/golang/go/issues/61383
// is addressed.
func ServeChannelRequests(ctx context.Context, client *ssh.Client, getForwardAgent ClientGetter) error {
channels := client.HandleChannelOpen(channelType)
if channels == nil {
return errors.New("agent forwarding channel already open")
}

go func() {
for ch := range channels {
channel, reqs, err := ch.Accept()
if err != nil {
continue
}

go ssh.DiscardRequests(reqs)

forwardAgent, err := getForwardAgent()
if err != nil {
_ = channel.Close()
slog.ErrorContext(ctx, "failed to connect to forwarded agent", "err", err)
continue
}

go func() {
defer channel.Close()
defer forwardAgent.Close()
if err := agent.ServeAgent(forwardAgent, channel); err != nil && !errors.Is(err, io.EOF) {
slog.ErrorContext(ctx, "unexpected error serving forwarded agent", "err", err)
}
}()
}
}()
return nil
}
133 changes: 133 additions & 0 deletions lib/sshagent/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
// Teleport
// Copyright (C) 2025 Gravitational, Inc.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package sshagent_test

import (
"context"
"io"
"net"
"path/filepath"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh/agent"

"github.com/gravitational/teleport/lib/sshagent"
"github.com/gravitational/teleport/lib/utils"
)

func TestSSHAgentClient(t *testing.T) {
keyring, ok := agent.NewKeyring().(agent.ExtendedAgent)
require.True(t, ok)

agentDir := t.TempDir()
agentPath := filepath.Join(agentDir, "agent.sock")
startAgentServer := func() (stop func()) {
l, err := net.Listen("unix", agentPath)
require.NoError(t, err)

// create a context to close existing connections on server shutdown.
serveCtx, serveCancel := context.WithCancel(t.Context())

// Track open connections.
var connWg sync.WaitGroup

connWg.Add(1)
go func() {
defer connWg.Done()
for {
conn, err := l.Accept()
if err != nil {
assert.True(t, utils.IsUseOfClosedNetworkError(err))
return
}

closeConn := func() {
conn.Close()
connWg.Done()
}

// Close the connection early if the server is stopped.
connNotClosed := context.AfterFunc(serveCtx, closeConn)

connWg.Add(1)
go func() {
defer func() {
if connNotClosed() {
closeConn()
}
}()
agent.ServeAgent(keyring, conn)
}()
}
}()

// Close the listener, stop serving, and wait for all open client
// connections to close.
stopServer := func() {
l.Close()
serveCancel()
connWg.Wait()
}

return stopServer
}

stopServer := startAgentServer()
t.Cleanup(stopServer)

clientConnect := func() (io.ReadWriteCloser, error) {
return net.Dial("unix", agentPath)
}
clientGetter := func() (sshagent.Client, error) {
return sshagent.NewClient(clientConnect)
}

// Get a new agent client and make successful requests.
agentClient, err := clientGetter()
require.NoError(t, err)
_, err = agentClient.List()
require.NoError(t, err)

// Close the server and all client connections, client should fail.
stopServer()
_, err = agentClient.List()
// TODO(Joerger): Ideally we would check for the error (io.EOF),
// but the agent library isn't properly wrapping its errors.
require.Error(t, err)

// Getting a new client should fail.
_, err = clientGetter()
require.Error(t, err)

// Re-open the server. Get a new agent client connection.
stopServer = startAgentServer()
t.Cleanup(stopServer)

agentClient, err = clientGetter()
require.NoError(t, err)
_, err = agentClient.List()
require.NoError(t, err)

// Close the client, it should return an error when receiving requests.
err = agentClient.Close()
require.NoError(t, err)
_, err = agentClient.List()
require.Error(t, err)
}
Loading
Loading