diff --git a/integration/integration_test.go b/integration/integration_test.go index 9084d9f64b233..b44ad9d3b0a5a 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -7557,13 +7557,40 @@ func testAgentlessConnection(t *testing.T, suite *integrationTestSuite) { // so closing it here will result in io.EOF if the test passes _ = session.Close() }) - require.NoError(t, agent.ForwardToAgent(sshClient, tc.LocalAgent())) + + // this is essentially what agent.ForwardToAgent does, but we're + // doing it manually so can take ownership of the opened SSH channel + // and check that it's closed correctly + channels := sshClient.HandleChannelOpen("auth-agent@openssh.com") + require.NotNil(t, channels) + + doneServing := make(chan error) + go func() { + for ch := range channels { + channel, reqs, err := ch.Accept() + assert.NoError(t, err) + go ssh.DiscardRequests(reqs) + go func() { + doneServing <- agent.ServeAgent(tc.LocalAgent(), channel) + channel.Close() + }() + } + }() + require.NoError(t, agent.RequestAgentForwarding(session)) // run a command err = session.Run("cmd") require.NoError(t, err) + // test that SSH agent channel is closed properly + select { + case err := <-doneServing: + require.ErrorIs(t, err, io.EOF) + case <-time.After(3 * time.Second): + require.Fail(t, "timeout waiting for SSH agent channel to be closed") + } + require.NoError(t, nodeClient.Close()) } diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index cd8e21fdbae9b..bf0ec659c1d73 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -1161,6 +1161,7 @@ func (s *Server) handleAgentForward(ch ssh.Channel, req *ssh.Request, ctx *srv.S if err != nil { return trace.Wrap(err) } + ctx.AddCloser(userAgent) } err = agent.ForwardToAgent(ctx.RemoteClient.Client, userAgent)