Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
436 changes: 339 additions & 97 deletions api/gen/proto/go/teleport/join/v1/joinservice.pb.go

Large diffs are not rendered by default.

41 changes: 41 additions & 0 deletions api/proto/teleport/join/v1/joinservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,44 @@ message TPMSolution {
bytes solution = 1;
}

// AzureInit is sent from the client in response to the ServerInit message for
// the Azure join method.
//
// The Azure method join flow is:
// 1. client->server: ClientInit
// 2. client<-server: ServerInit
// 3. client->server: AzureInit
// 4. client<-server: AzureChallenge
// 5. client->server: AzureChallengeSolution
// 6. client<-server: Result
message AzureInit {
// ClientParams holds parameters for the specific type of client trying to join.
ClientParams client_params = 1;
}

// AzureChallenge is sent from the server in response to the AzureInit message from the client.
// The client is expected to respond with a AzureChallengeSolution.
message AzureChallenge {
// Challenge is a a crypto-random string that should be included by the
// client in the challenge response message.
string challenge = 1;
}

// AzureChallengeSolution must be sent from the client in response to the
// AzureChallenge message.
message AzureChallengeSolution {
// AttestedData is a signed JSON document from an Azure VM's attested data
// metadata endpoint used to prove the identity of a joining node. It must
// include the challenge string as the nonce.
bytes attested_data = 1;
// Intermediate encodes the intermediate CAs that issued the leaf certificate
// used to sign the attested data document, in x509 DER format.
bytes intermediate = 2;
// AccessToken is a JWT signed by Azure, used to prove the identity of a
// joining node.
string access_token = 3;
}

// ChallengeSolution holds a solution to a challenge issued by the server.
message ChallengeSolution {
oneof payload {
Expand All @@ -357,6 +395,7 @@ message ChallengeSolution {
IAMChallengeSolution iam_challenge_solution = 3;
OracleChallengeSolution oracle_challenge_solution = 4;
TPMSolution tpm_solution = 5;
AzureChallengeSolution azure_challenge_solution = 6;
}
}

Expand Down Expand Up @@ -396,6 +435,7 @@ message JoinRequest {
OIDCInit oidc_init = 8;
OracleInit oracle_init = 9;
TPMInit tpm_init = 10;
AzureInit azure_init = 11;
}
}

Expand All @@ -417,6 +457,7 @@ message Challenge {
IAMChallenge iam_challenge = 3;
OracleChallenge oracle_challenge = 4;
TPMEncryptedCredential tpm_encrypted_credential = 5;
AzureChallenge azure_challenge = 6;
}
}

Expand Down
4 changes: 4 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ import (
iterstream "github.com/gravitational/teleport/lib/itertools/stream"
"github.com/gravitational/teleport/lib/join"
"github.com/gravitational/teleport/lib/join/azuredevops"
"github.com/gravitational/teleport/lib/join/azurejoin"
"github.com/gravitational/teleport/lib/join/bitbucket"
joinboundkeypair "github.com/gravitational/teleport/lib/join/boundkeypair"
"github.com/gravitational/teleport/lib/join/circleci"
Expand Down Expand Up @@ -1329,6 +1330,9 @@ type Server struct {
// override the implementation used in tests.
env0IDTokenValidator join.Env0TokenValidator

// azureJoinConfig holds configuration for the Azure join method.
azureJoinConfig *azurejoin.AzureJoinConfig

// loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node.
loadAllCAs bool

Expand Down
92 changes: 0 additions & 92 deletions lib/auth/bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@ import (
"context"
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"io"
Expand All @@ -36,10 +32,8 @@ import (
"text/template"
"time"

"github.com/digitorus/pkcs7"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/require"
Expand All @@ -58,16 +52,13 @@ import (
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/authtest"
"github.com/gravitational/teleport/lib/auth/machineid/machineidv1"
"github.com/gravitational/teleport/lib/auth/state"
"github.com/gravitational/teleport/lib/auth/testauthority"
"github.com/gravitational/teleport/lib/cloud/azure"
libevents "github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/eventstest"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/join/iamjoin"
"github.com/gravitational/teleport/lib/join/joinclient"
"github.com/gravitational/teleport/lib/kube/token"
Expand Down Expand Up @@ -744,89 +735,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) {
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})

t.Run("Azure method", func(t *testing.T) {
subID := uuid.NewString()
resourceGroup := "rg"
rsID := vmResourceID(subID, resourceGroup, "test-vm")
vmID := "vmID"

accessToken, err := makeToken(rsID, "", a.GetClock().Now())
require.NoError(t, err)

// add token to auth server
azureTokenName := "azure-test-token"
azureToken, err := types.NewProvisionTokenFromSpec(
azureTokenName,
time.Now().Add(time.Minute),
types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleBot},
Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}},
BotName: botName,
JoinMethod: types.JoinMethodAzure,
})
require.NoError(t, err)
require.NoError(t, a.UpsertToken(ctx, azureToken))

vmClient := &mockAzureVMClient{
vms: map[string]*azure.VirtualMachine{
rsID: {
ID: rsID,
Name: "test-vm",
Subscription: subID,
ResourceGroup: resourceGroup,
VMID: vmID,
},
},
}
getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{
subID: vmClient,
})

tlsConfig, err := fixtures.LocalTLSConfig()
require.NoError(t, err)

block, _ := pem.Decode(fixtures.LocalhostKey)
pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
require.NoError(t, err)

certs, err := a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
ad := auth.AttestedData{
Nonce: challenge,
SubscriptionID: subID,
ID: vmID,
}
adBytes, err := json.Marshal(&ad)
require.NoError(t, err)
s, err := pkcs7.NewSignedData(adBytes)
require.NoError(t, err)
require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{}))
signature, err := s.Finish()
require.NoError(t, err)
signedAD := auth.SignedAttestedData{
Encoding: "pkcs7",
Signature: base64.StdEncoding.EncodeToString(signature),
}
signedADBytes, err := json.Marshal(&signedAD)
require.NoError(t, err)

req := &proto.RegisterUsingAzureMethodRequest{
RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{
Token: azureTokenName,
HostID: "test-node",
Role: types.RoleBot,
PublicSSHKey: sshPubKey,
PublicTLSKey: tlsPubKey,
RemoteAddr: remoteAddr,
},
AttestedData: signedADBytes,
AccessToken: accessToken,
}
return req, nil
}, auth.WithAzureCerts([]*x509.Certificate{tlsConfig.Certificate}), auth.WithAzureVerifyFunc(mockVerifyToken(nil)), auth.WithAzureVMClientGetter(getVMClient))
require.NoError(t, err)
checkCertLoginIP(t, certs.TLS, remoteAddr)
})
}

func responseFromAWSIdentity(id iamjoin.AWSIdentity) string {
Expand Down
31 changes: 0 additions & 31 deletions lib/auth/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,6 @@ const (

MaxUserAgentLen = maxUserAgentLen
ForwardedTag = forwardedTag

AzureAccessTokenAudience = azureAccessTokenAudience
)

var (
Expand Down Expand Up @@ -307,10 +305,6 @@ func TrimUserAgent(userAgent string) string {
return trimUserAgent(userAgent)
}

func IsAllowedDomain(cn string, domains []string) bool {
return isAllowedDomain(cn, domains)
}

func GetSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) {
return getSnowflakeJWTParams(ctx, accountName, userName, publicKey)
}
Expand All @@ -336,31 +330,6 @@ func CheckHeaders(headers http.Header, challenge string, clock clockwork.Clock)
}

type GitHubManager = githubManager
type AttestedData = attestedData
type SignedAttestedData = signedAttestedData
type AzureRegisterOption = azureRegisterOption
type AzureRegisterConfig = azureRegisterConfig
type AzureVMClientGetter = vmClientGetter
type AzureVerifyTokenFunc = azureVerifyTokenFunc
type AccessTokenClaims = accessTokenClaims

func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.certificateAuthorities = certs
}
}

func WithAzureVerifyFunc(verify azureVerifyTokenFunc) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.verify = verify
}
}

func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption {
return func(cfg *AzureRegisterConfig) {
cfg.getVMClient = getVMClient
}
}

func (s *TLSServer) GRPCServer() *GRPCServer {
return s.grpcServer
Expand Down
18 changes: 17 additions & 1 deletion lib/auth/join/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ type AzureParams struct {
// ClientID is the client ID of the managed identity for Teleport to assume
// when authenticating a node.
ClientID string
// IMDSClient overrides the client used to fetch data from Azure IMDS.
IMDSClient AzureIMDSClient
// IssuerHTTPClient, if set, overrides the default HTTP client used to
// fetch the intermediate CA which issued the attested data document
// signing certificate. Only used when joining via the new join service.
IssuerHTTPClient utils.HTTPDoClient
}

// AzureIMDSClient is a client to Azure's IMDS.
type AzureIMDSClient interface {
IsAvailable(context.Context) bool
GetAttestedData(ctx context.Context, nonce string) ([]byte, error)
GetAccessToken(ctx context.Context, clientID string) (string, error)
}

// GitlabParams is the parameters specific to the gitlab join method.
Expand Down Expand Up @@ -875,7 +888,10 @@ func registerUsingAzureMethod(
ctx context.Context, client joinServiceClient, token string, hostKeys *newHostKeys, params RegisterParams,
) (*proto.Certs, error) {
certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) {
imds := azure.NewInstanceMetadataClient()
imds := params.AzureParams.IMDSClient
if imds == nil {
imds = azure.NewInstanceMetadataClient()
}
if !imds.IsAvailable(ctx) {
return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?")
}
Expand Down
Loading
Loading