Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,10 +579,6 @@ func (s *APIServer) registerUsingToken(auth ClientI, w http.ResponseWriter, r *h
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}

// Pass along the remote address the request came from to the registration function.
req.RemoteAddr = r.RemoteAddr

certs, err := auth.RegisterUsingToken(r.Context(), &req)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
7 changes: 7 additions & 0 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,13 @@ func (a *ServerWithRoles) UpdateUserCARoleMap(ctx context.Context, name string,
}

func (a *ServerWithRoles) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) {
// We do not trust remote addr in the request unless it's coming from the Proxy.
if !a.hasBuiltinRole(types.RoleProxy) || req.RemoteAddr == "" {
if err := setRemoteAddrFromContext(ctx, req); err != nil {
return nil, trace.Wrap(err)
}
}

// tokens have authz mechanism on their own, no need to check
return a.authServer.RegisterUsingToken(ctx, req)
}
Expand Down
3 changes: 2 additions & 1 deletion lib/auth/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ func (a *Server) validateGenerationLabel(ctx context.Context, username string, c
// care if the current identity is Nop. This function does not validate the
// current identity at all; the caller is expected to validate that the client
// is allowed to issue the (possibly renewable) certificates.
func (a *Server) generateInitialBotCerts(ctx context.Context, username string, pubKey []byte, expires time.Time, renewable bool) (*proto.Certs, error) {
func (a *Server) generateInitialBotCerts(ctx context.Context, username, loginIP string, pubKey []byte, expires time.Time, renewable bool) (*proto.Certs, error) {
var err error

// Extract the user and role set for whom the certificate will be generated.
Expand Down Expand Up @@ -579,6 +579,7 @@ func (a *Server) generateInitialBotCerts(ctx context.Context, username string, p
renewable: renewable,
includeHostCA: true,
generation: generation,
loginIP: loginIP,
}

if err := a.validateGenerationLabel(ctx, userState.GetName(), &certReq, 0); err != nil {
Expand Down
179 changes: 179 additions & 0 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,20 @@ limitations under the License.
package auth

import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"net/http"
"strings"
"testing"
"time"

"github.com/digitorus/pkcs7"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
Expand All @@ -31,6 +40,8 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/wrappers"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
)
Expand Down Expand Up @@ -223,6 +234,7 @@ func TestRegisterBotCertificateGenerationCheck(t *testing.T) {
PublicSSHKey: publicKey,
})
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, "127.0.0.1")

tlsCert, err := tls.X509KeyPair(certs.TLS, privateKey)
require.NoError(t, err)
Expand Down Expand Up @@ -307,3 +319,170 @@ func TestRegisterBotCertificateGenerationStolen(t *testing.T) {
require.NoError(t, err)
require.NotEmpty(t, locks)
}

// TestRegisterBot_RemoteAddr checks that certs returned for bot registration contain specified in the request remote addr.
func TestRegisterBot_RemoteAddr(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

p, err := newTestPack(ctx, t.TempDir())
require.NoError(t, err)
a := p.a

sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair()
require.NoError(t, err)

tlsPublicKey, err := PrivateKeyToPublicKeyTLS(sshPrivateKey)
require.NoError(t, err)

roleName := "test-role"
_, err = CreateRole(ctx, a, roleName, types.RoleSpecV6{})
require.NoError(t, err)

botName := "botty"
_, err = a.createBot(ctx, &proto.CreateBotRequest{
Name: botName,
Roles: []string{roleName},
})
require.NoError(t, err)

remoteAddr := "42.42.42.42:42"

t.Run("IAM method", func(t *testing.T) {
a.httpClientForAWSSTS = &mockClient{
respStatusCode: http.StatusOK,
respBody: responseFromAWSIdentity(awsIdentity{
Account: "1234",
Arn: "arn:aws::1111",
}),
}

// add token to auth server
awsTokenName := "aws-test-token"
awsToken, err := types.NewProvisionTokenFromSpec(
awsTokenName,
time.Now().Add(time.Minute),
types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleBot},
Allow: []*types.TokenRule{
{
AWSAccount: "1234",
AWSARN: "arn:aws::1111",
},
},
BotName: botName,
JoinMethod: types.JoinMethodIAM,
})
require.NoError(t, err)
require.NoError(t, a.UpsertToken(ctx, awsToken))

certs, err := a.RegisterUsingIAMMethod(context.Background(), func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
templateInput := defaultIdentityRequestTemplateInput(challenge)
var identityRequest bytes.Buffer
require.NoError(t, identityRequestTemplate.Execute(&identityRequest, templateInput))

req := &proto.RegisterUsingIAMMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: awsTokenName,
HostID: "test-bot",
Role: types.RoleBot,
PublicSSHKey: sshPublicKey,
PublicTLSKey: tlsPublicKey,
RemoteAddr: "42.42.42.42:42",
},
StsIdentityRequest: identityRequest.Bytes(),
}
return req, nil
})
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})

t.Run("Azure method", func(t *testing.T) {
subID := uuid.NewString()
resourceGroup := "rg"
rsID := resourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, a.clock.Now())
require.NoError(t, err)

// add token to auth server
azureTokenName := "azure-test-token"
azureToken, err := types.NewProvisionTokenFromSpec(
azureTokenName,
time.Now().Add(time.Minute),
types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleBot},
Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}},
BotName: botName,
JoinMethod: types.JoinMethodAzure,
})
require.NoError(t, err)
require.NoError(t, a.UpsertToken(ctx, azureToken))

vmClient := &mockAzureVMClient{vm: &azure.VirtualMachine{
ID: rsID,
Name: "test-vm",
Subscription: subID,
ResourceGroup: resourceGroup,
VMID: vmID,
}}

tlsConfig, err := fixtures.LocalTLSConfig()
require.NoError(t, err)

block, _ := pem.Decode(fixtures.LocalhostKey)
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

certs, err := a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
ad := attestedData{
Nonce: challenge,
SubscriptionID: subID,
ID: vmID,
}
adBytes, err := json.Marshal(&ad)
require.NoError(t, err)
s, err := pkcs7.NewSignedData(adBytes)
require.NoError(t, err)
require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
signature, err := s.Finish()
require.NoError(t, err)
signedAD := signedAttestedData{
Encoding: "pkcs7",
Signature: base64.StdEncoding.EncodeToString(signature),
}
signedADBytes, err := json.Marshal(&signedAD)
require.NoError(t, err)

req := &proto.RegisterUsingAzureMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: azureTokenName,
HostID: "test-node",
Role: types.RoleBot,
PublicSSHKey: sshPublicKey,
PublicTLSKey: tlsPublicKey,
RemoteAddr: remoteAddr,
},
AttestedData: signedADBytes,
AccessToken: accessToken,
}
return req, nil
}, withCerts([]*x509.Certificate{tlsConfig.Certificate}), withVerifyFunc(mockVerifyToken(nil)), withVMClient(vmClient))
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})
}

func checkCertLoginIP(t *testing.T, certBytes []byte, loginIP string) {
t.Helper()

cert, err := tlsca.ParseCertificatePEM(certBytes)
require.NoError(t, err)
identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter)
require.NoError(t, err)
require.True(t, strings.HasPrefix(identity.LoginIP, loginIP))
}
21 changes: 20 additions & 1 deletion lib/auth/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,17 @@ import (
"crypto/rand"
"encoding/base64"
"fmt"
"net"
"strings"

"github.com/gravitational/trace"
"golang.org/x/exp/slices"
"google.golang.org/grpc/peer"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
)
Expand Down Expand Up @@ -92,6 +95,22 @@ type joinAttributeSourcer interface {
JoinAuditAttributes() (map[string]interface{}, error)
}

func setRemoteAddrFromContext(ctx context.Context, req *types.RegisterUsingTokenRequest) error {
var addr string
if clientIP, err := authz.ClientSrcAddrFromContext(ctx); err == nil {
addr = clientIP.String()
} else if p, ok := peer.FromContext(ctx); ok {
addr = p.Addr.String()
}
ip, _, err := net.SplitHostPort(addr)
if err != nil {
return trace.Wrap(err)
}
req.RemoteAddr = ip

return nil
}

// RegisterUsingToken returns credentials for a new node to join the Teleport
// cluster using a previously issued token.
//
Expand Down Expand Up @@ -217,7 +236,7 @@ func (a *Server) generateCertsBot(
}

certs, err := a.generateInitialBotCerts(
ctx, BotResourceName(botName), req.PublicSSHKey, expires, renewable,
ctx, BotResourceName(botName), req.RemoteAddr, req.PublicSSHKey, expires, renewable,
)
if err != nil {
return nil, trace.Wrap(err)
Expand Down
7 changes: 0 additions & 7 deletions lib/auth/join_azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ import (
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/utils"
)
Expand Down Expand Up @@ -348,11 +347,6 @@ func (a *Server) RegisterUsingAzureMethod(ctx context.Context, challengeResponse
return nil, trace.Wrap(err)
}

clientAddr, err := authz.ClientSrcAddrFromContext(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

challenge, err := generateAzureChallenge()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -362,7 +356,6 @@ func (a *Server) RegisterUsingAzureMethod(ctx context.Context, challengeResponse
return nil, trace.Wrap(err)
}

req.RegisterUsingTokenRequest.RemoteAddr = clientAddr.String()
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
7 changes: 1 addition & 6 deletions lib/auth/join_azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"encoding/json"
"encoding/pem"
"fmt"
"net"
"testing"
"time"

Expand All @@ -37,7 +36,6 @@ import (
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/authz"
"github.com/gravitational/teleport/lib/cloud/azure"
"github.com/gravitational/teleport/lib/fixtures"
)
Expand Down Expand Up @@ -384,9 +382,6 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {
accessToken, err := makeToken(rsID, a.clock.Now())
require.NoError(t, err)

reqCtx := context.Background()
reqCtx = authz.ContextWithClientSrcAddr(reqCtx, &net.IPAddr{})

vmResult := tc.vmResult
if vmResult == nil {
vmResult = &azure.VirtualMachine{
Expand All @@ -400,7 +395,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) {

vmClient := &mockAzureVMClient{vm: vmResult}

_, err = a.RegisterUsingAzureMethod(reqCtx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
_, err = a.RegisterUsingAzureMethod(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
cfg := &azureChallengeResponseConfig{Challenge: challenge}
for _, opt := range tc.challengeResponseOptions {
opt(cfg)
Expand Down
8 changes: 0 additions & 8 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ import (
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/authz"
cloudaws "github.com/gravitational/teleport/lib/cloud/aws"
"github.com/gravitational/teleport/lib/utils"
"github.com/gravitational/teleport/lib/utils/aws"
Expand Down Expand Up @@ -348,11 +347,6 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c
opt(cfg)
}

clientAddr, err := authz.ClientSrcAddrFromContext(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

challenge, err := generateIAMChallenge()
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -363,8 +357,6 @@ func (a *Server) RegisterUsingIAMMethod(ctx context.Context, challengeResponse c
return nil, trace.Wrap(err)
}

// fill in the client remote addr to the register request
req.RegisterUsingTokenRequest.RemoteAddr = clientAddr.String()
if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading