diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 72afed845ead7..8b71c446ce9fc 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -579,10 +579,6 @@ func (s *APIServer) registerUsingToken(auth ClientI, w http.ResponseWriter, r *h if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } - - // Pass along the remote address the request came from to the registration function. - req.RemoteAddr = r.RemoteAddr - certs, err := auth.RegisterUsingToken(r.Context(), &req) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index f4d0d8babac5c..bc560177fabe7 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -982,6 +982,13 @@ func (a *ServerWithRoles) UpdateUserCARoleMap(ctx context.Context, name string, } func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { + // We do not trust remote addr in the request unless it's coming from the Proxy. + if !a.hasBuiltinRole(types.RoleProxy) || req.RemoteAddr == "" { + if err := setRemoteAddrFromContext(ctx, req); err != nil { + return nil, trace.Wrap(err) + } + } + // tokens have authz mechanism on their own, no need to check return a.authServer.RegisterUsingToken(ctx, req) } diff --git a/lib/auth/bot.go b/lib/auth/bot.go index b6d2e7a7ee61f..c20d998725487 100644 --- a/lib/auth/bot.go +++ b/lib/auth/bot.go @@ -526,7 +526,7 @@ func (a *Server) validateGenerationLabel(ctx context.Context, username string, c // care if the current identity is Nop. This function does not validate the // current identity at all; the caller is expected to validate that the client // is allowed to issue the (possibly renewable) certificates. -func (a *Server) generateInitialBotCerts(ctx context.Context, username string, pubKey []byte, expires time.Time, renewable bool) (*proto.Certs, error) { +func (a *Server) generateInitialBotCerts(ctx context.Context, username, loginIP string, pubKey []byte, expires time.Time, renewable bool) (*proto.Certs, error) { var err error // Extract the user and role set for whom the certificate will be generated. @@ -579,6 +579,7 @@ func (a *Server) generateInitialBotCerts(ctx context.Context, username string, p renewable: renewable, includeHostCA: true, generation: generation, + loginIP: loginIP, } if err := a.validateGenerationLabel(ctx, userState.GetName(), &certReq, 0); err != nil { diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index cefe1964173ad..27a9eaa72f6b8 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -17,11 +17,20 @@ limitations under the License. package auth import ( + "bytes" "context" "crypto/tls" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "net/http" + "strings" "testing" "time" + "github.com/digitorus/pkcs7" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" "golang.org/x/crypto/ssh" @@ -31,6 +40,8 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/wrappers" "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -223,6 +234,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) { PublicSSHKey: publicKey, }) require.NoError(t, err) + checkCertLoginIP(t, certs.TLS, "127.0.0.1") tlsCert, err := tls.X509KeyPair(certs.TLS, privateKey) require.NoError(t, err) @@ -307,3 +319,170 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) { require.NoError(t, err) require.NotEmpty(t, locks) } + +// TestRegisterBot_RemoteAddr checks that certs returned for bot registration contain specified in the request remote addr. +func TestRegisterBot_RemoteAddr(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + p, err := newTestPack(ctx, t.TempDir()) + require.NoError(t, err) + a := p.a + + sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() + require.NoError(t, err) + + tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey) + require.NoError(t, err) + + roleName := "test-role" + _, err = CreateRole(ctx, a, roleName, types.RoleSpecV6{}) + require.NoError(t, err) + + botName := "botty" + _, err = a.createBot(ctx, &proto.CreateBotRequest{ + Name: botName, + Roles: []string{roleName}, + }) + require.NoError(t, err) + + remoteAddr := "42.42.42.42:42" + + t.Run("IAM method", func(t *testing.T) { + a.httpClientForAWSSTS = &mockClient{ + respStatusCode: http.StatusOK, + respBody: responseFromAWSIdentity(awsIdentity{ + Account: "1234", + Arn: "arn:aws::1111", + }), + } + + // add token to auth server + awsTokenName := "aws-test-token" + awsToken, err := types.NewProvisionTokenFromSpec( + awsTokenName, + time.Now().Add(time.Minute), + types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleBot}, + Allow: []*types.TokenRule{ + { + AWSAccount: "1234", + AWSARN: "arn:aws::1111", + }, + }, + BotName: botName, + JoinMethod: types.JoinMethodIAM, + }) + require.NoError(t, err) + require.NoError(t, a.UpsertToken(ctx, awsToken)) + + certs, err := a.RegisterUsingIAMMethod(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + templateInput := defaultIdentityRequestTemplateInput(challenge) + var identityRequest bytes.Buffer + require.NoError(t, identityRequestTemplate.Execute(&identityRequest, templateInput)) + + req := &proto.RegisterUsingIAMMethodRequest{ + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ + Token: awsTokenName, + HostID: "test-bot", + Role: types.RoleBot, + PublicSSHKey: sshPublicKey, + PublicTLSKey: tlsPublicKey, + RemoteAddr: "42.42.42.42:42", + }, + StsIdentityRequest: identityRequest.Bytes(), + } + return req, nil + }) + require.NoError(t, err) + checkCertLoginIP(t, certs.TLS, remoteAddr) + }) + + t.Run("Azure method", func(t *testing.T) { + subID := uuid.NewString() + resourceGroup := "rg" + rsID := resourceID(subID, resourceGroup, "test-vm") + vmID := "vmID" + + accessToken, err := makeToken(rsID, a.clock.Now()) + require.NoError(t, err) + + // add token to auth server + azureTokenName := "azure-test-token" + azureToken, err := types.NewProvisionTokenFromSpec( + azureTokenName, + time.Now().Add(time.Minute), + types.ProvisionTokenSpecV2{ + Roles: []types.SystemRole{types.RoleBot}, + Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}}, + BotName: botName, + JoinMethod: types.JoinMethodAzure, + }) + require.NoError(t, err) + require.NoError(t, a.UpsertToken(ctx, azureToken)) + + vmClient := &mockAzureVMClient{vm: &azure.VirtualMachine{ + ID: rsID, + Name: "test-vm", + Subscription: subID, + ResourceGroup: resourceGroup, + VMID: vmID, + }} + + tlsConfig, err := fixtures.LocalTLSConfig() + require.NoError(t, err) + + block, _ := pem.Decode(fixtures.LocalhostKey) + pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + require.NoError(t, err) + + certs, err := a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { + ad := attestedData{ + Nonce: challenge, + SubscriptionID: subID, + ID: vmID, + } + adBytes, err := json.Marshal(&ad) + require.NoError(t, err) + s, err := pkcs7.NewSignedData(adBytes) + require.NoError(t, err) + require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) + signature, err := s.Finish() + require.NoError(t, err) + signedAD := signedAttestedData{ + Encoding: "pkcs7", + Signature: base64.StdEncoding.EncodeToString(signature), + } + signedADBytes, err := json.Marshal(&signedAD) + require.NoError(t, err) + + req := &proto.RegisterUsingAzureMethodRequest{ + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ + Token: azureTokenName, + HostID: "test-node", + Role: types.RoleBot, + PublicSSHKey: sshPublicKey, + PublicTLSKey: tlsPublicKey, + RemoteAddr: remoteAddr, + }, + AttestedData: signedADBytes, + AccessToken: accessToken, + } + return req, nil + }, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClient(vmClient)) + require.NoError(t, err) + checkCertLoginIP(t, certs.TLS, remoteAddr) + }) +} + +func checkCertLoginIP(t *testing.T, certBytes []byte, loginIP string) { + t.Helper() + + cert, err := tlsca.ParseCertificatePEM(certBytes) + require.NoError(t, err) + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + require.True(t, strings.HasPrefix(identity.LoginIP, loginIP)) +} diff --git a/lib/auth/join.go b/lib/auth/join.go index bfc2a08e850ec..7be9e3321bc15 100644 --- a/lib/auth/join.go +++ b/lib/auth/join.go @@ -21,14 +21,17 @@ import ( "crypto/rand" "encoding/base64" "fmt" + "net" "strings" "github.com/gravitational/trace" "golang.org/x/exp/slices" + "google.golang.org/grpc/peer" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" ) @@ -92,6 +95,22 @@ type joinAttributeSourcer interface { JoinAuditAttributes() (map[string]interface{}, error) } +func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingTokenRequest) error { + var addr string + if clientIP, err := authz.ClientSrcAddrFromContext(ctx); err == nil { + addr = clientIP.String() + } else if p, ok := peer.FromContext(ctx); ok { + addr = p.Addr.String() + } + ip, _, err := net.SplitHostPort(addr) + if err != nil { + return trace.Wrap(err) + } + req.RemoteAddr = ip + + return nil +} + // RegisterUsingToken returns credentials for a new node to join the Teleport // cluster using a previously issued token. // @@ -217,7 +236,7 @@ func (a *Server) generateCertsBot( } certs, err := a.generateInitialBotCerts( - ctx, BotResourceName(botName), req.PublicSSHKey, expires, renewable, + ctx, BotResourceName(botName), req.RemoteAddr, req.PublicSSHKey, expires, renewable, ) if err != nil { return nil, trace.Wrap(err) diff --git a/lib/auth/join_azure.go b/lib/auth/join_azure.go index c370b1f959ff6..10fbca8853923 100644 --- a/lib/auth/join_azure.go +++ b/lib/auth/join_azure.go @@ -36,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/utils" ) @@ -348,11 +347,6 @@ func (a *Server) RegisterUsingAzureMethod(ctx context.Context, challengeResponse return nil, trace.Wrap(err) } - clientAddr, err := authz.ClientSrcAddrFromContext(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - challenge, err := generateAzureChallenge() if err != nil { return nil, trace.Wrap(err) @@ -362,7 +356,6 @@ func (a *Server) RegisterUsingAzureMethod(ctx context.Context, challengeResponse return nil, trace.Wrap(err) } - req.RegisterUsingTokenRequest.RemoteAddr = clientAddr.String() if err := req.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/join_azure_test.go b/lib/auth/join_azure_test.go index 439d73d390ea0..72f568555ec52 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/auth/join_azure_test.go @@ -23,7 +23,6 @@ import ( "encoding/json" "encoding/pem" "fmt" - "net" "testing" "time" @@ -37,7 +36,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/cloud/azure" "github.com/gravitational/teleport/lib/fixtures" ) @@ -384,9 +382,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { accessToken, err := makeToken(rsID, a.clock.Now()) require.NoError(t, err) - reqCtx := context.Background() - reqCtx = authz.ContextWithClientSrcAddr(reqCtx, &net.IPAddr{}) - vmResult := tc.vmResult if vmResult == nil { vmResult = &azure.VirtualMachine{ @@ -400,7 +395,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { vmClient := &mockAzureVMClient{vm: vmResult} - _, err = a.RegisterUsingAzureMethod(reqCtx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { + _, err = a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { cfg := &azureChallengeResponseConfig{Challenge: challenge} for _, opt := range tc.challengeResponseOptions { opt(cfg) diff --git a/lib/auth/join_iam.go b/lib/auth/join_iam.go index 5fba4b844d2ec..54367306c1512 100644 --- a/lib/auth/join_iam.go +++ b/lib/auth/join_iam.go @@ -38,7 +38,6 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/authz" cloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/aws" @@ -348,11 +347,6 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c opt(cfg) } - clientAddr, err := authz.ClientSrcAddrFromContext(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - challenge, err := generateIAMChallenge() if err != nil { return nil, trace.Wrap(err) @@ -363,8 +357,6 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c return nil, trace.Wrap(err) } - // fill in the client remote addr to the register request - req.RegisterUsingTokenRequest.RemoteAddr = clientAddr.String() if err := req.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/join_iam_test.go b/lib/auth/join_iam_test.go index 22ee297972f3f..6ab0d11f8078c 100644 --- a/lib/auth/join_iam_test.go +++ b/lib/auth/join_iam_test.go @@ -21,7 +21,6 @@ import ( "context" "fmt" "io" - "net" "net/http" "strings" "testing" @@ -35,7 +34,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/utils" ) @@ -505,10 +503,7 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) { require.NoError(t, a.DeleteToken(ctx, token.GetName())) }() - requestContext := context.Background() - requestContext = authz.ContextWithClientSrcAddr(requestContext, &net.IPAddr{}) - - _, err = a.RegisterUsingIAMMethod(requestContext, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { + _, err = a.RegisterUsingIAMMethod(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) { templateInput := defaultIdentityRequestTemplateInput(challenge) for _, opt := range tc.challengeResponseOptions { opt(&templateInput) diff --git a/lib/joinserver/joinserver.go b/lib/joinserver/joinserver.go index 23ee163fc2d73..a8713dcd6fd0b 100644 --- a/lib/joinserver/joinserver.go +++ b/lib/joinserver/joinserver.go @@ -20,6 +20,8 @@ package joinserver import ( "context" + "net" + "slices" "time" "github.com/gravitational/trace" @@ -29,6 +31,9 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/authz" + "github.com/gravitational/teleport/lib/tlsca" ) const ( @@ -109,7 +114,14 @@ func (s *JoinServiceGRPCServer) registerUsingIAMMethod(ctx context.Context, srv // Then get the response from the client and return it. req, err := srv.Recv() - return req, trace.Wrap(err) + if err != nil { + return nil, trace.Wrap(err) + } + if err := setClientRemoteAddr(ctx, req.RegisterUsingTokenRequest); err != nil { + return nil, trace.Wrap(err) + } + + return req, nil }) if err != nil { return trace.Wrap(err) @@ -160,6 +172,28 @@ func (s *JoinServiceGRPCServer) RegisterUsingAzureMethod(srv proto.JoinService_R } } +func checkForProxyRole(identity tlsca.Identity) bool { + const proxyRole = string(types.RoleProxy) + return slices.Contains(identity.Groups, proxyRole) || slices.Contains(identity.SystemRoles, proxyRole) +} + +func setClientRemoteAddr(ctx context.Context, req *types.RegisterUsingTokenRequest) error { + // If request is coming from the Proxy, trust the IP set on the request. + if user, err := authz.UserFromContext(ctx); err == nil && checkForProxyRole(user.GetIdentity()) { + return nil + } + // Otherwise this is (likely) the proxy, set the IP from the connection. + p, ok := peer.FromContext(ctx) + if !ok { + return trace.BadParameter("could not get peer from the context") + } + req.RemoteAddr = p.Addr.String() // Addr without port is used in tests. + if ip, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + req.RemoteAddr = ip + } + return nil +} + func (s *JoinServiceGRPCServer) registerUsingAzureMethod(ctx context.Context, srv proto.JoinService_RegisterUsingAzureMethodServer) error { certs, err := s.joinServiceClient.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { err := srv.Send(&proto.RegisterUsingAzureMethodResponse{ @@ -170,7 +204,14 @@ func (s *JoinServiceGRPCServer) registerUsingAzureMethod(ctx context.Context, sr } req, err := srv.Recv() - return req, trace.Wrap(err) + if err != nil { + return nil, trace.Wrap(err) + } + if err := setClientRemoteAddr(ctx, req.RegisterUsingTokenRequest); err != nil { + return nil, trace.Wrap(err) + } + + return req, nil }) if err != nil { return trace.Wrap(err) diff --git a/lib/joinserver/joinserver_test.go b/lib/joinserver/joinserver_test.go index 18dd20cd12f56..6e79818e6cba9 100644 --- a/lib/joinserver/joinserver_test.go +++ b/lib/joinserver/joinserver_test.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/grpc/interceptors" ) @@ -171,16 +172,18 @@ func TestJoinServiceGRPCServer_RegisterUsingIAMMethod(t *testing.T) { certs *proto.Certs }{ { - desc: "pass case", - challenge: "foo", - challengeResponse: &proto.RegisterUsingIAMMethodRequest{StsIdentityRequest: []byte("bar")}, - certs: &proto.Certs{SSH: []byte("baz")}, + desc: "pass case", + challenge: "foo", + challengeResponse: &proto.RegisterUsingIAMMethodRequest{StsIdentityRequest: []byte("bar"), + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{}}, + certs: &proto.Certs{SSH: []byte("baz")}, }, { - desc: "auth error", - challenge: "foo", - challengeResponse: &proto.RegisterUsingIAMMethodRequest{StsIdentityRequest: []byte("bar")}, - authErr: trace.AccessDenied("test auth error"), + desc: "auth error", + challenge: "foo", + challengeResponse: &proto.RegisterUsingIAMMethodRequest{StsIdentityRequest: []byte("bar"), + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{}}, + authErr: trace.AccessDenied("test auth error"), }, { desc: "challenge response error", @@ -221,8 +224,10 @@ func TestJoinServiceGRPCServer_RegisterUsingIAMMethod(t *testing.T) { require.NoError(t, err) // client should get the certs from auth require.Equal(t, tc.certs, certs) - // auth should get the challenge response from client - require.Equal(t, tc.challengeResponse, testPack.mockAuthServer.gotIAMChallengeResponse) + // auth should get the challenge response from client with remote addr set to connection src addr + expectedResponse := tc.challengeResponse + expectedResponse.RegisterUsingTokenRequest.RemoteAddr = "bufconn" + require.Equal(t, expectedResponse, testPack.mockAuthServer.gotIAMChallengeResponse) }) } }) @@ -242,16 +247,18 @@ func TestJoinServiceGRPCServer_RegisterUsingAzureMethod(t *testing.T) { certs *proto.Certs }{ { - desc: "pass case", - challenge: "foo", - challengeResponse: &proto.RegisterUsingAzureMethodRequest{AttestedData: []byte("bar"), AccessToken: "baz"}, - certs: &proto.Certs{SSH: []byte("qux")}, + desc: "pass case", + challenge: "foo", + challengeResponse: &proto.RegisterUsingAzureMethodRequest{AttestedData: []byte("bar"), AccessToken: "baz", + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{}}, + certs: &proto.Certs{SSH: []byte("qux")}, }, { - desc: "auth error", - challenge: "foo", - challengeResponse: &proto.RegisterUsingAzureMethodRequest{AttestedData: []byte("bar"), AccessToken: "baz"}, - authErr: trace.AccessDenied("test auth error"), + desc: "auth error", + challenge: "foo", + challengeResponse: &proto.RegisterUsingAzureMethodRequest{AttestedData: []byte("bar"), AccessToken: "baz", + RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{}}, + authErr: trace.AccessDenied("test auth error"), }, { desc: "challenge response error", @@ -285,7 +292,9 @@ func TestJoinServiceGRPCServer_RegisterUsingAzureMethod(t *testing.T) { } require.NoError(t, err) require.Equal(t, tc.certs, certs) - require.Equal(t, tc.challengeResponse, testPack.mockAuthServer.gotAzureChallengeResponse) + expectedResponse := tc.challengeResponse + expectedResponse.RegisterUsingTokenRequest.RemoteAddr = "bufconn" + require.Equal(t, expectedResponse, testPack.mockAuthServer.gotAzureChallengeResponse) }) } }) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index c83574d6e3fad..09bab28fe9ba2 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3566,6 +3566,11 @@ func (h *Handler) hostCredentials(w http.ResponseWriter, r *http.Request, p http } authClient := h.cfg.ProxyClient + remoteAddr, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return nil, trace.Wrap(err) + } + req.RemoteAddr = remoteAddr certs, err := authClient.RegisterUsingToken(r.Context(), &req) if err != nil { return nil, trace.Wrap(err)