Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ import (
iterstream "github.com/gravitational/teleport/lib/itertools/stream"
"github.com/gravitational/teleport/lib/join"
"github.com/gravitational/teleport/lib/join/azuredevops"
"github.com/gravitational/teleport/lib/join/azurejoin"
"github.com/gravitational/teleport/lib/join/bitbucket"
joinboundkeypair "github.com/gravitational/teleport/lib/join/boundkeypair"
"github.com/gravitational/teleport/lib/join/circleci"
Expand Down Expand Up @@ -1339,6 +1340,9 @@ type Server struct {
// override the implementation used in tests.
env0IDTokenValidator join.Env0TokenValidator

// azureJoinConfig holds configuration for the Azure join method.
azureJoinConfig *azurejoin.AzureJoinConfig

// loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node.
loadAllCAs bool

Expand Down
92 changes: 0 additions & 92 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import (
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
Expand All @@ -36,10 +32,8 @@ import (
"text/template"
"time"

"github.com/digitorus/pkcs7"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand All @@ -58,16 +52,13 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/authtest"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/cloud/azure"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/join/iamjoin"
"github.com/gravitational/teleport/lib/join/joinclient"
"github.com/gravitational/teleport/lib/kube/token"
Expand Down Expand Up @@ -741,89 +732,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})

t.Run("Azure method", func(t *testing.T) {
subID := uuid.NewString()
resourceGroup := "rg"
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, "", a.GetClock().Now())
require.NoError(t, err)

// add token to auth server
azureTokenName := "azure-test-token"
azureToken, err := types.NewProvisionTokenFromSpec(
azureTokenName,
time.Now().Add(time.Minute),
types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleBot},
Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}},
BotName: botName,
JoinMethod: types.JoinMethodAzure,
})
require.NoError(t, err)
require.NoError(t, a.UpsertToken(ctx, azureToken))

vmClient := &mockAzureVMClient{
vms: map[string]*azure.VirtualMachine{
rsID: {
ID: rsID,
Name: "test-vm",
Subscription: subID,
ResourceGroup: resourceGroup,
VMID: vmID,
},
},
}
getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
subID: vmClient,
})

tlsConfig, err := fixtures.LocalTLSConfig()
require.NoError(t, err)

block, _ := pem.Decode(fixtures.LocalhostKey)
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

certs, err := a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
ad := auth.AttestedData{
Nonce: challenge,
SubscriptionID: subID,
ID: vmID,
}
adBytes, err := json.Marshal(&ad)
require.NoError(t, err)
s, err := pkcs7.NewSignedData(adBytes)
require.NoError(t, err)
require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
signature, err := s.Finish()
require.NoError(t, err)
signedAD := auth.SignedAttestedData{
Encoding: "pkcs7",
Signature: base64.StdEncoding.EncodeToString(signature),
}
signedADBytes, err := json.Marshal(&signedAD)
require.NoError(t, err)

req := &proto.RegisterUsingAzureMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: azureTokenName,
HostID: "test-node",
Role: types.RoleBot,
PublicSSHKey: sshPubKey,
PublicTLSKey: tlsPubKey,
RemoteAddr: remoteAddr,
},
AttestedData: signedADBytes,
AccessToken: accessToken,
}
return req, nil
}, auth.WithAzureCerts([]*x509.Certificate{tlsConfig.Certificate}), auth.WithAzureVerifyFunc(mockVerifyToken(nil)), auth.WithAzureVMClientGetter(getVMClient))
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})
}

func responseFromAWSIdentity(id iamjoin.AWSIdentity) string {
Expand Down
31 changes: 0 additions & 31 deletions lib/auth/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ const (

MaxUserAgentLen = maxUserAgentLen
ForwardedTag = forwardedTag

AzureAccessTokenAudience = azureAccessTokenAudience
)

var (
Expand Down Expand Up @@ -307,10 +305,6 @@ func TrimUserAgent(userAgent string) string {
return trimUserAgent(userAgent)
}

func IsAllowedDomain(cn string, domains []string) bool {
return isAllowedDomain(cn, domains)
}

func GetSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) {
return getSnowflakeJWTParams(ctx, accountName, userName, publicKey)
}
Expand All @@ -336,31 +330,6 @@ func CheckHeaders(headers http.Header, challenge string, clock clockwork.Clock)
}

type GitHubManager = githubManager
type AttestedData = attestedData
type SignedAttestedData = signedAttestedData
type AzureRegisterOption = azureRegisterOption
type AzureRegisterConfig = azureRegisterConfig
type AzureVMClientGetter = vmClientGetter
type AzureVerifyTokenFunc = azureVerifyTokenFunc
type AccessTokenClaims = accessTokenClaims

func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.certificateAuthorities = certs
}
}

func WithAzureVerifyFunc(verify azureVerifyTokenFunc) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.verify = verify
}
}

func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.getVMClient = getVMClient
}
}

func (s *TLSServer) GRPCServer() *GRPCServer {
return s.grpcServer
Expand Down
18 changes: 17 additions & 1 deletion lib/auth/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ type AzureParams struct {
// ClientID is the client ID of the managed identity for Teleport to assume
// when authenticating a node.
ClientID string
// IMDSClient overrides the client used to fetch data from Azure IMDS.
IMDSClient AzureIMDSClient
// IssuerHTTPClient, if set, overrides the default HTTP client used to
// fetch the intermediate CA which issued the attested data document
// signing certificate. Only used when joining via the new join service.
IssuerHTTPClient utils.HTTPDoClient
}

// AzureIMDSClient is a client to Azure's IMDS.
type AzureIMDSClient interface {
IsAvailable(context.Context) bool
GetAttestedData(ctx context.Context, nonce string) ([]byte, error)
GetAccessToken(ctx context.Context, clientID string) (string, error)
}

// GitlabParams is the parameters specific to the gitlab join method.
Expand Down Expand Up @@ -875,7 +888,10 @@ func registerUsingAzureMethod(
ctx context.Context, client joinServiceClient, token string, hostKeys *newHostKeys, params RegisterParams,
) (*proto.Certs, error) {
certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
imds := azure.NewInstanceMetadataClient()
imds := params.AzureParams.IMDSClient
if imds == nil {
imds = azure.NewInstanceMetadataClient()
}
if !imds.IsAvailable(ctx) {
return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
}
Expand Down
119 changes: 119 additions & 0 deletions lib/auth/join_azure_legacy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Teleport
* Copyright (C) 2023 Gravitational, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

package auth

import (
"context"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/join/azurejoin"
"github.com/gravitational/teleport/lib/join/legacyjoin"
)

// RegisterUsingAzureMethod registers the caller using the Azure join method
// and returns signed certs to join the cluster.
//
// The caller must provide a ChallengeResponseFunc which returns a
// *proto.RegisterUsingAzureMethodRequest with a signed attested data document
// including the challenge as a nonce.
//
// TODO(nklaassen): DELETE IN 20 when removing the legacy join service.
func (a *Server) RegisterUsingAzureMethod(
ctx context.Context,
challengeResponse client.RegisterAzureChallengeResponseFunc,
) (certs *proto.Certs, err error) {
var provisionToken types.ProvisionToken
var joinRequest *types.RegisterUsingTokenRequest
defer func() {
// Emit a log message and audit event on join failure.
if err != nil {
a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest)
}
}()

if legacyjoin.Disabled() {
return nil, trace.Wrap(legacyjoin.ErrDisabled)
}

challenge, err := azurejoin.GenerateAzureChallenge()
if err != nil {
return nil, trace.Wrap(err)
}
req, err := challengeResponse(challenge)
if err != nil {
return nil, trace.Wrap(err)
}
joinRequest = req.RegisterUsingTokenRequest

if err := req.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}

provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest)
if err != nil {
return nil, trace.Wrap(err)
}
if provisionToken.GetJoinMethod() != types.JoinMethodAzure {
return nil, trace.AccessDenied("this token does not support the Azure join method")
}

ptv2, ok := provisionToken.(*types.ProvisionTokenV2)
if !ok {
return nil, trace.Wrap(err, "Azure join method only supports ProvisionTokenV2, got %T", provisionToken)
}

joinAttrs, err := azurejoin.CheckAzureRequest(ctx, azurejoin.CheckAzureRequestParams{
AzureJoinConfig: a.GetAzureJoinConfig(),
Token: ptv2,
Challenge: challenge,
AttestedData: req.AttestedData,
AccessToken: req.AccessToken,
Logger: a.logger,
Clock: a.GetClock(),
})
if err != nil {
return nil, trace.Wrap(err, "checking Azure challenge response")
}

if req.RegisterUsingTokenRequest.Role == types.RoleBot {
params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{
Azure: joinAttrs,
})
certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params)
return certs, trace.Wrap(err)
}
params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/)
certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params)
return certs, trace.Wrap(err)
}

// GetAzureJoinConfig gets configuration options for azure joining.
func (a *Server) GetAzureJoinConfig() *azurejoin.AzureJoinConfig {
return a.azureJoinConfig
}

// SetAzureJoinConfig sets configuration options for azure joining.
func (a *Server) SetAzureJoinConfig(c *azurejoin.AzureJoinConfig) {
a.azureJoinConfig = c
}
Loading
Loading