From dc93c419c3437d312cae77ef62e15978fe4358c7 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 26 Nov 2025 11:26:21 -0800 Subject: [PATCH 1/2] [v18] azure join message types Backport #61128 to branch/v18 --- .../go/teleport/join/v1/joinservice.pb.go | 436 ++++++++++++++---- api/proto/teleport/join/v1/joinservice.proto | 41 ++ lib/join/internal/messages/messages.go | 45 ++ lib/join/joinv1/messages.go | 34 +- lib/join/joinv1/messages_azure.go | 72 +++ lib/join/joinv1/messages_test.go | 16 + 6 files changed, 545 insertions(+), 99 deletions(-) create mode 100644 lib/join/joinv1/messages_azure.go diff --git a/api/gen/proto/go/teleport/join/v1/joinservice.pb.go b/api/gen/proto/go/teleport/join/v1/joinservice.pb.go index f198a20c83b23..143a26cb1a3f5 100644 --- a/api/gen/proto/go/teleport/join/v1/joinservice.pb.go +++ b/api/gen/proto/go/teleport/join/v1/joinservice.pb.go @@ -93,7 +93,7 @@ func (x GivingUp_Reason) Number() protoreflect.EnumNumber { // Deprecated: Use GivingUp_Reason.Descriptor instead. func (GivingUp_Reason) EnumDescriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{24, 0} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{27, 0} } // ClientInit is the first message sent from the client during the join process, it @@ -1572,6 +1572,178 @@ func (x *TPMSolution) GetSolution() []byte { return nil } +// AzureInit is sent from the client in response to the ServerInit message for +// the Azure join method. +// +// The Azure method join flow is: +// 1. client->server: ClientInit +// 2. client<-server: ServerInit +// 3. client->server: AzureInit +// 4. client<-server: AzureChallenge +// 5. client->server: AzureChallengeSolution +// 6. client<-server: Result +type AzureInit struct { + state protoimpl.MessageState `protogen:"open.v1"` + // ClientParams holds parameters for the specific type of client trying to join. + ClientParams *ClientParams `protobuf:"bytes,1,opt,name=client_params,json=clientParams,proto3" json:"client_params,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AzureInit) Reset() { + *x = AzureInit{} + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[23] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AzureInit) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AzureInit) ProtoMessage() {} + +func (x *AzureInit) ProtoReflect() protoreflect.Message { + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[23] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AzureInit.ProtoReflect.Descriptor instead. +func (*AzureInit) Descriptor() ([]byte, []int) { + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{23} +} + +func (x *AzureInit) GetClientParams() *ClientParams { + if x != nil { + return x.ClientParams + } + return nil +} + +// AzureChallenge is sent from the server in response to the AzureInit message from the client. +// The client is expected to respond with a AzureChallengeSolution. +type AzureChallenge struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Challenge is a a crypto-random string that should be included by the + // client in the challenge response message. + Challenge string `protobuf:"bytes,1,opt,name=challenge,proto3" json:"challenge,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AzureChallenge) Reset() { + *x = AzureChallenge{} + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[24] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AzureChallenge) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AzureChallenge) ProtoMessage() {} + +func (x *AzureChallenge) ProtoReflect() protoreflect.Message { + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[24] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AzureChallenge.ProtoReflect.Descriptor instead. +func (*AzureChallenge) Descriptor() ([]byte, []int) { + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{24} +} + +func (x *AzureChallenge) GetChallenge() string { + if x != nil { + return x.Challenge + } + return "" +} + +// AzureChallengeSolution must be sent from the client in response to the +// AzureChallenge message. +type AzureChallengeSolution struct { + state protoimpl.MessageState `protogen:"open.v1"` + // AttestedData is a signed JSON document from an Azure VM's attested data + // metadata endpoint used to prove the identity of a joining node. It must + // include the challenge string as the nonce. + AttestedData []byte `protobuf:"bytes,1,opt,name=attested_data,json=attestedData,proto3" json:"attested_data,omitempty"` + // Intermediate encodes the intermediate CAs that issued the leaf certificate + // used to sign the attested data document, in x509 DER format. + Intermediate []byte `protobuf:"bytes,2,opt,name=intermediate,proto3" json:"intermediate,omitempty"` + // AccessToken is a JWT signed by Azure, used to prove the identity of a + // joining node. + AccessToken string `protobuf:"bytes,3,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AzureChallengeSolution) Reset() { + *x = AzureChallengeSolution{} + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[25] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AzureChallengeSolution) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AzureChallengeSolution) ProtoMessage() {} + +func (x *AzureChallengeSolution) ProtoReflect() protoreflect.Message { + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[25] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AzureChallengeSolution.ProtoReflect.Descriptor instead. +func (*AzureChallengeSolution) Descriptor() ([]byte, []int) { + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{25} +} + +func (x *AzureChallengeSolution) GetAttestedData() []byte { + if x != nil { + return x.AttestedData + } + return nil +} + +func (x *AzureChallengeSolution) GetIntermediate() []byte { + if x != nil { + return x.Intermediate + } + return nil +} + +func (x *AzureChallengeSolution) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + // ChallengeSolution holds a solution to a challenge issued by the server. type ChallengeSolution struct { state protoimpl.MessageState `protogen:"open.v1"` @@ -1582,6 +1754,7 @@ type ChallengeSolution struct { // *ChallengeSolution_IamChallengeSolution // *ChallengeSolution_OracleChallengeSolution // *ChallengeSolution_TpmSolution + // *ChallengeSolution_AzureChallengeSolution Payload isChallengeSolution_Payload `protobuf_oneof:"payload"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -1589,7 +1762,7 @@ type ChallengeSolution struct { func (x *ChallengeSolution) Reset() { *x = ChallengeSolution{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[23] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[26] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1601,7 +1774,7 @@ func (x *ChallengeSolution) String() string { func (*ChallengeSolution) ProtoMessage() {} func (x *ChallengeSolution) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[23] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[26] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1614,7 +1787,7 @@ func (x *ChallengeSolution) ProtoReflect() protoreflect.Message { // Deprecated: Use ChallengeSolution.ProtoReflect.Descriptor instead. func (*ChallengeSolution) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{23} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{26} } func (x *ChallengeSolution) GetPayload() isChallengeSolution_Payload { @@ -1669,6 +1842,15 @@ func (x *ChallengeSolution) GetTpmSolution() *TPMSolution { return nil } +func (x *ChallengeSolution) GetAzureChallengeSolution() *AzureChallengeSolution { + if x != nil { + if x, ok := x.Payload.(*ChallengeSolution_AzureChallengeSolution); ok { + return x.AzureChallengeSolution + } + } + return nil +} + type isChallengeSolution_Payload interface { isChallengeSolution_Payload() } @@ -1693,6 +1875,10 @@ type ChallengeSolution_TpmSolution struct { TpmSolution *TPMSolution `protobuf:"bytes,5,opt,name=tpm_solution,json=tpmSolution,proto3,oneof"` } +type ChallengeSolution_AzureChallengeSolution struct { + AzureChallengeSolution *AzureChallengeSolution `protobuf:"bytes,6,opt,name=azure_challenge_solution,json=azureChallengeSolution,proto3,oneof"` +} + func (*ChallengeSolution_BoundKeypairChallengeSolution) isChallengeSolution_Payload() {} func (*ChallengeSolution_BoundKeypairRotationResponse) isChallengeSolution_Payload() {} @@ -1703,6 +1889,8 @@ func (*ChallengeSolution_OracleChallengeSolution) isChallengeSolution_Payload() func (*ChallengeSolution_TpmSolution) isChallengeSolution_Payload() {} +func (*ChallengeSolution_AzureChallengeSolution) isChallengeSolution_Payload() {} + // GivingUp should be sent by clients that fail to complete the join flow so // that the Auth service can log an informative error message. type GivingUp struct { @@ -1717,7 +1905,7 @@ type GivingUp struct { func (x *GivingUp) Reset() { *x = GivingUp{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[24] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[27] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1729,7 +1917,7 @@ func (x *GivingUp) String() string { func (*GivingUp) ProtoMessage() {} func (x *GivingUp) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[24] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[27] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1742,7 +1930,7 @@ func (x *GivingUp) ProtoReflect() protoreflect.Message { // Deprecated: Use GivingUp.ProtoReflect.Descriptor instead. func (*GivingUp) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{24} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{27} } func (x *GivingUp) GetReason() GivingUp_Reason { @@ -1774,6 +1962,7 @@ type JoinRequest struct { // *JoinRequest_OidcInit // *JoinRequest_OracleInit // *JoinRequest_TpmInit + // *JoinRequest_AzureInit Payload isJoinRequest_Payload `protobuf_oneof:"payload"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -1781,7 +1970,7 @@ type JoinRequest struct { func (x *JoinRequest) Reset() { *x = JoinRequest{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[25] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1793,7 +1982,7 @@ func (x *JoinRequest) String() string { func (*JoinRequest) ProtoMessage() {} func (x *JoinRequest) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[25] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1806,7 +1995,7 @@ func (x *JoinRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use JoinRequest.ProtoReflect.Descriptor instead. func (*JoinRequest) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{25} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{28} } func (x *JoinRequest) GetPayload() isJoinRequest_Payload { @@ -1906,6 +2095,15 @@ func (x *JoinRequest) GetTpmInit() *TPMInit { return nil } +func (x *JoinRequest) GetAzureInit() *AzureInit { + if x != nil { + if x, ok := x.Payload.(*JoinRequest_AzureInit); ok { + return x.AzureInit + } + } + return nil +} + type isJoinRequest_Payload interface { isJoinRequest_Payload() } @@ -1950,6 +2148,10 @@ type JoinRequest_TpmInit struct { TpmInit *TPMInit `protobuf:"bytes,10,opt,name=tpm_init,json=tpmInit,proto3,oneof"` } +type JoinRequest_AzureInit struct { + AzureInit *AzureInit `protobuf:"bytes,11,opt,name=azure_init,json=azureInit,proto3,oneof"` +} + func (*JoinRequest_ClientInit) isJoinRequest_Payload() {} func (*JoinRequest_TokenInit) isJoinRequest_Payload() {} @@ -1970,6 +2172,8 @@ func (*JoinRequest_OracleInit) isJoinRequest_Payload() {} func (*JoinRequest_TpmInit) isJoinRequest_Payload() {} +func (*JoinRequest_AzureInit) isJoinRequest_Payload() {} + // ServerInit is the first message sent from the server in response to the // ClientInit message. type ServerInit struct { @@ -1985,7 +2189,7 @@ type ServerInit struct { func (x *ServerInit) Reset() { *x = ServerInit{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[26] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1997,7 +2201,7 @@ func (x *ServerInit) String() string { func (*ServerInit) ProtoMessage() {} func (x *ServerInit) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[26] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2010,7 +2214,7 @@ func (x *ServerInit) ProtoReflect() protoreflect.Message { // Deprecated: Use ServerInit.ProtoReflect.Descriptor instead. func (*ServerInit) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{26} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{29} } func (x *ServerInit) GetJoinMethod() string { @@ -2037,6 +2241,7 @@ type Challenge struct { // *Challenge_IamChallenge // *Challenge_OracleChallenge // *Challenge_TpmEncryptedCredential + // *Challenge_AzureChallenge Payload isChallenge_Payload `protobuf_oneof:"payload"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache @@ -2044,7 +2249,7 @@ type Challenge struct { func (x *Challenge) Reset() { *x = Challenge{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[27] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2056,7 +2261,7 @@ func (x *Challenge) String() string { func (*Challenge) ProtoMessage() {} func (x *Challenge) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[27] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2069,7 +2274,7 @@ func (x *Challenge) ProtoReflect() protoreflect.Message { // Deprecated: Use Challenge.ProtoReflect.Descriptor instead. func (*Challenge) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{27} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{30} } func (x *Challenge) GetPayload() isChallenge_Payload { @@ -2124,6 +2329,15 @@ func (x *Challenge) GetTpmEncryptedCredential() *TPMEncryptedCredential { return nil } +func (x *Challenge) GetAzureChallenge() *AzureChallenge { + if x != nil { + if x, ok := x.Payload.(*Challenge_AzureChallenge); ok { + return x.AzureChallenge + } + } + return nil +} + type isChallenge_Payload interface { isChallenge_Payload() } @@ -2148,6 +2362,10 @@ type Challenge_TpmEncryptedCredential struct { TpmEncryptedCredential *TPMEncryptedCredential `protobuf:"bytes,5,opt,name=tpm_encrypted_credential,json=tpmEncryptedCredential,proto3,oneof"` } +type Challenge_AzureChallenge struct { + AzureChallenge *AzureChallenge `protobuf:"bytes,6,opt,name=azure_challenge,json=azureChallenge,proto3,oneof"` +} + func (*Challenge_BoundKeypairChallenge) isChallenge_Payload() {} func (*Challenge_BoundKeypairRotationRequest) isChallenge_Payload() {} @@ -2158,6 +2376,8 @@ func (*Challenge_OracleChallenge) isChallenge_Payload() {} func (*Challenge_TpmEncryptedCredential) isChallenge_Payload() {} +func (*Challenge_AzureChallenge) isChallenge_Payload() {} + // Result is the final message sent from the cluster back to the client, it // contains the result of the joining process including the assigned host ID // and issued certificates. @@ -2174,7 +2394,7 @@ type Result struct { func (x *Result) Reset() { *x = Result{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[28] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2186,7 +2406,7 @@ func (x *Result) String() string { func (*Result) ProtoMessage() {} func (x *Result) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[28] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2199,7 +2419,7 @@ func (x *Result) ProtoReflect() protoreflect.Message { // Deprecated: Use Result.ProtoReflect.Descriptor instead. func (*Result) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{28} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{31} } func (x *Result) GetPayload() isResult_Payload { @@ -2262,7 +2482,7 @@ type Certificates struct { func (x *Certificates) Reset() { *x = Certificates{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[29] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2274,7 +2494,7 @@ func (x *Certificates) String() string { func (*Certificates) ProtoMessage() {} func (x *Certificates) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[29] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2287,7 +2507,7 @@ func (x *Certificates) ProtoReflect() protoreflect.Message { // Deprecated: Use Certificates.ProtoReflect.Descriptor instead. func (*Certificates) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{29} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{32} } func (x *Certificates) GetTlsCert() []byte { @@ -2331,7 +2551,7 @@ type HostResult struct { func (x *HostResult) Reset() { *x = HostResult{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[30] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2343,7 +2563,7 @@ func (x *HostResult) String() string { func (*HostResult) ProtoMessage() {} func (x *HostResult) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[30] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2356,7 +2576,7 @@ func (x *HostResult) ProtoReflect() protoreflect.Message { // Deprecated: Use HostResult.ProtoReflect.Descriptor instead. func (*HostResult) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{30} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{33} } func (x *HostResult) GetCertificates() *Certificates { @@ -2386,7 +2606,7 @@ type BotResult struct { func (x *BotResult) Reset() { *x = BotResult{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[31] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2398,7 +2618,7 @@ func (x *BotResult) String() string { func (*BotResult) ProtoMessage() {} func (x *BotResult) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[31] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2411,7 +2631,7 @@ func (x *BotResult) ProtoReflect() protoreflect.Message { // Deprecated: Use BotResult.ProtoReflect.Descriptor instead. func (*BotResult) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{31} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{34} } func (x *BotResult) GetCertificates() *Certificates { @@ -2443,7 +2663,7 @@ type JoinResponse struct { func (x *JoinResponse) Reset() { *x = JoinResponse{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[32] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[35] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2455,7 +2675,7 @@ func (x *JoinResponse) String() string { func (*JoinResponse) ProtoMessage() {} func (x *JoinResponse) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[32] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[35] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2468,7 +2688,7 @@ func (x *JoinResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use JoinResponse.ProtoReflect.Descriptor instead. func (*JoinResponse) Descriptor() ([]byte, []int) { - return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{32} + return file_teleport_join_v1_joinservice_proto_rawDescGZIP(), []int{35} } func (x *JoinResponse) GetPayload() isJoinResponse_Payload { @@ -2551,7 +2771,7 @@ type ClientInit_ProxySuppliedParams struct { func (x *ClientInit_ProxySuppliedParams) Reset() { *x = ClientInit_ProxySuppliedParams{} - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[33] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[36] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -2563,7 +2783,7 @@ func (x *ClientInit_ProxySuppliedParams) String() string { func (*ClientInit_ProxySuppliedParams) ProtoMessage() {} func (x *ClientInit_ProxySuppliedParams) ProtoReflect() protoreflect.Message { - mi := &file_teleport_join_v1_joinservice_proto_msgTypes[33] + mi := &file_teleport_join_v1_joinservice_proto_msgTypes[36] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -2695,13 +2915,22 @@ const file_teleport_join_v1_joinservice_proto_rawDesc = "" + "\x0fcredential_blob\x18\x01 \x01(\fR\x0ecredentialBlob\x12\x16\n" + "\x06secret\x18\x02 \x01(\fR\x06secret\")\n" + "\vTPMSolution\x12\x1a\n" + - "\bsolution\x18\x01 \x01(\fR\bsolution\"\xa0\x04\n" + + "\bsolution\x18\x01 \x01(\fR\bsolution\"P\n" + + "\tAzureInit\x12C\n" + + "\rclient_params\x18\x01 \x01(\v2\x1e.teleport.join.v1.ClientParamsR\fclientParams\".\n" + + "\x0eAzureChallenge\x12\x1c\n" + + "\tchallenge\x18\x01 \x01(\tR\tchallenge\"\x84\x01\n" + + "\x16AzureChallengeSolution\x12#\n" + + "\rattested_data\x18\x01 \x01(\fR\fattestedData\x12\"\n" + + "\fintermediate\x18\x02 \x01(\fR\fintermediate\x12!\n" + + "\faccess_token\x18\x03 \x01(\tR\vaccessToken\"\x86\x05\n" + "\x11ChallengeSolution\x12z\n" + " bound_keypair_challenge_solution\x18\x01 \x01(\v2/.teleport.join.v1.BoundKeypairChallengeSolutionH\x00R\x1dboundKeypairChallengeSolution\x12w\n" + "\x1fbound_keypair_rotation_response\x18\x02 \x01(\v2..teleport.join.v1.BoundKeypairRotationResponseH\x00R\x1cboundKeypairRotationResponse\x12^\n" + "\x16iam_challenge_solution\x18\x03 \x01(\v2&.teleport.join.v1.IAMChallengeSolutionH\x00R\x14iamChallengeSolution\x12g\n" + "\x19oracle_challenge_solution\x18\x04 \x01(\v2).teleport.join.v1.OracleChallengeSolutionH\x00R\x17oracleChallengeSolution\x12B\n" + - "\ftpm_solution\x18\x05 \x01(\v2\x1d.teleport.join.v1.TPMSolutionH\x00R\vtpmSolutionB\t\n" + + "\ftpm_solution\x18\x05 \x01(\v2\x1d.teleport.join.v1.TPMSolutionH\x00R\vtpmSolution\x12d\n" + + "\x18azure_challenge_solution\x18\x06 \x01(\v2(.teleport.join.v1.AzureChallengeSolutionH\x00R\x16azureChallengeSolutionB\t\n" + "\apayload\"\xe9\x01\n" + "\bGivingUp\x129\n" + "\x06reason\x18\x01 \x01(\x0e2!.teleport.join.v1.GivingUp.ReasonR\x06reason\x12\x10\n" + @@ -2710,7 +2939,7 @@ const file_teleport_join_v1_joinservice_proto_rawDesc = "" + "\x12REASON_UNSPECIFIED\x10\x00\x12\"\n" + "\x1eREASON_UNSUPPORTED_JOIN_METHOD\x10\x01\x12#\n" + "\x1fREASON_UNSUPPORTED_MESSAGE_TYPE\x10\x02\x12$\n" + - " REASON_CHALLENGE_SOLUTION_FAILED\x10\x03\"\x8d\x05\n" + + " REASON_CHALLENGE_SOLUTION_FAILED\x10\x03\"\xcb\x05\n" + "\vJoinRequest\x12?\n" + "\vclient_init\x18\x01 \x01(\v2\x1c.teleport.join.v1.ClientInitH\x00R\n" + "clientInit\x12<\n" + @@ -2725,19 +2954,22 @@ const file_teleport_join_v1_joinservice_proto_rawDesc = "" + "\voracle_init\x18\t \x01(\v2\x1c.teleport.join.v1.OracleInitH\x00R\n" + "oracleInit\x126\n" + "\btpm_init\x18\n" + - " \x01(\v2\x19.teleport.join.v1.TPMInitH\x00R\atpmInitB\t\n" + + " \x01(\v2\x19.teleport.join.v1.TPMInitH\x00R\atpmInit\x12<\n" + + "\n" + + "azure_init\x18\v \x01(\v2\x1b.teleport.join.v1.AzureInitH\x00R\tazureInitB\t\n" + "\apayload\"i\n" + "\n" + "ServerInit\x12\x1f\n" + "\vjoin_method\x18\x01 \x01(\tR\n" + "joinMethod\x12:\n" + - "\x19signature_algorithm_suite\x18\x02 \x01(\tR\x17signatureAlgorithmSuite\"\xec\x03\n" + + "\x19signature_algorithm_suite\x18\x02 \x01(\tR\x17signatureAlgorithmSuite\"\xb9\x04\n" + "\tChallenge\x12a\n" + "\x17bound_keypair_challenge\x18\x01 \x01(\v2'.teleport.join.v1.BoundKeypairChallengeH\x00R\x15boundKeypairChallenge\x12t\n" + "\x1ebound_keypair_rotation_request\x18\x02 \x01(\v2-.teleport.join.v1.BoundKeypairRotationRequestH\x00R\x1bboundKeypairRotationRequest\x12E\n" + "\riam_challenge\x18\x03 \x01(\v2\x1e.teleport.join.v1.IAMChallengeH\x00R\fiamChallenge\x12N\n" + "\x10oracle_challenge\x18\x04 \x01(\v2!.teleport.join.v1.OracleChallengeH\x00R\x0foracleChallenge\x12d\n" + - "\x18tpm_encrypted_credential\x18\x05 \x01(\v2(.teleport.join.v1.TPMEncryptedCredentialH\x00R\x16tpmEncryptedCredentialB\t\n" + + "\x18tpm_encrypted_credential\x18\x05 \x01(\v2(.teleport.join.v1.TPMEncryptedCredentialH\x00R\x16tpmEncryptedCredential\x12K\n" + + "\x0fazure_challenge\x18\x06 \x01(\v2 .teleport.join.v1.AzureChallengeH\x00R\x0eazureChallengeB\t\n" + "\apayload\"\x92\x01\n" + "\x06Result\x12?\n" + "\vhost_result\x18\x01 \x01(\v2\x1c.teleport.join.v1.HostResultH\x00R\n" + @@ -2780,7 +3012,7 @@ func file_teleport_join_v1_joinservice_proto_rawDescGZIP() []byte { } var file_teleport_join_v1_joinservice_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_teleport_join_v1_joinservice_proto_msgTypes = make([]protoimpl.MessageInfo, 34) +var file_teleport_join_v1_joinservice_proto_msgTypes = make([]protoimpl.MessageInfo, 37) var file_teleport_join_v1_joinservice_proto_goTypes = []any{ (GivingUp_Reason)(0), // 0: teleport.join.v1.GivingUp.Reason (*ClientInit)(nil), // 1: teleport.join.v1.ClientInit @@ -2806,24 +3038,27 @@ var file_teleport_join_v1_joinservice_proto_goTypes = []any{ (*TPMInit)(nil), // 21: teleport.join.v1.TPMInit (*TPMEncryptedCredential)(nil), // 22: teleport.join.v1.TPMEncryptedCredential (*TPMSolution)(nil), // 23: teleport.join.v1.TPMSolution - (*ChallengeSolution)(nil), // 24: teleport.join.v1.ChallengeSolution - (*GivingUp)(nil), // 25: teleport.join.v1.GivingUp - (*JoinRequest)(nil), // 26: teleport.join.v1.JoinRequest - (*ServerInit)(nil), // 27: teleport.join.v1.ServerInit - (*Challenge)(nil), // 28: teleport.join.v1.Challenge - (*Result)(nil), // 29: teleport.join.v1.Result - (*Certificates)(nil), // 30: teleport.join.v1.Certificates - (*HostResult)(nil), // 31: teleport.join.v1.HostResult - (*BotResult)(nil), // 32: teleport.join.v1.BotResult - (*JoinResponse)(nil), // 33: teleport.join.v1.JoinResponse - (*ClientInit_ProxySuppliedParams)(nil), // 34: teleport.join.v1.ClientInit.ProxySuppliedParams - (*timestamppb.Timestamp)(nil), // 35: google.protobuf.Timestamp + (*AzureInit)(nil), // 24: teleport.join.v1.AzureInit + (*AzureChallenge)(nil), // 25: teleport.join.v1.AzureChallenge + (*AzureChallengeSolution)(nil), // 26: teleport.join.v1.AzureChallengeSolution + (*ChallengeSolution)(nil), // 27: teleport.join.v1.ChallengeSolution + (*GivingUp)(nil), // 28: teleport.join.v1.GivingUp + (*JoinRequest)(nil), // 29: teleport.join.v1.JoinRequest + (*ServerInit)(nil), // 30: teleport.join.v1.ServerInit + (*Challenge)(nil), // 31: teleport.join.v1.Challenge + (*Result)(nil), // 32: teleport.join.v1.Result + (*Certificates)(nil), // 33: teleport.join.v1.Certificates + (*HostResult)(nil), // 34: teleport.join.v1.HostResult + (*BotResult)(nil), // 35: teleport.join.v1.BotResult + (*JoinResponse)(nil), // 36: teleport.join.v1.JoinResponse + (*ClientInit_ProxySuppliedParams)(nil), // 37: teleport.join.v1.ClientInit.ProxySuppliedParams + (*timestamppb.Timestamp)(nil), // 38: google.protobuf.Timestamp } var file_teleport_join_v1_joinservice_proto_depIdxs = []int32{ - 34, // 0: teleport.join.v1.ClientInit.proxy_supplied_parameters:type_name -> teleport.join.v1.ClientInit.ProxySuppliedParams + 37, // 0: teleport.join.v1.ClientInit.proxy_supplied_parameters:type_name -> teleport.join.v1.ClientInit.ProxySuppliedParams 2, // 1: teleport.join.v1.HostParams.public_keys:type_name -> teleport.join.v1.PublicKeys 2, // 2: teleport.join.v1.BotParams.public_keys:type_name -> teleport.join.v1.PublicKeys - 35, // 3: teleport.join.v1.BotParams.expires:type_name -> google.protobuf.Timestamp + 38, // 3: teleport.join.v1.BotParams.expires:type_name -> google.protobuf.Timestamp 3, // 4: teleport.join.v1.ClientParams.host_params:type_name -> teleport.join.v1.HostParams 4, // 5: teleport.join.v1.ClientParams.bot_params:type_name -> teleport.join.v1.BotParams 5, // 6: teleport.join.v1.TokenInit.client_params:type_name -> teleport.join.v1.ClientParams @@ -2833,42 +3068,46 @@ var file_teleport_join_v1_joinservice_proto_depIdxs = []int32{ 5, // 10: teleport.join.v1.EC2Init.client_params:type_name -> teleport.join.v1.ClientParams 5, // 11: teleport.join.v1.OracleInit.client_params:type_name -> teleport.join.v1.ClientParams 5, // 12: teleport.join.v1.TPMInit.client_params:type_name -> teleport.join.v1.ClientParams - 10, // 13: teleport.join.v1.ChallengeSolution.bound_keypair_challenge_solution:type_name -> teleport.join.v1.BoundKeypairChallengeSolution - 12, // 14: teleport.join.v1.ChallengeSolution.bound_keypair_rotation_response:type_name -> teleport.join.v1.BoundKeypairRotationResponse - 16, // 15: teleport.join.v1.ChallengeSolution.iam_challenge_solution:type_name -> teleport.join.v1.IAMChallengeSolution - 20, // 16: teleport.join.v1.ChallengeSolution.oracle_challenge_solution:type_name -> teleport.join.v1.OracleChallengeSolution - 23, // 17: teleport.join.v1.ChallengeSolution.tpm_solution:type_name -> teleport.join.v1.TPMSolution - 0, // 18: teleport.join.v1.GivingUp.reason:type_name -> teleport.join.v1.GivingUp.Reason - 1, // 19: teleport.join.v1.JoinRequest.client_init:type_name -> teleport.join.v1.ClientInit - 6, // 20: teleport.join.v1.JoinRequest.token_init:type_name -> teleport.join.v1.TokenInit - 8, // 21: teleport.join.v1.JoinRequest.bound_keypair_init:type_name -> teleport.join.v1.BoundKeypairInit - 24, // 22: teleport.join.v1.JoinRequest.solution:type_name -> teleport.join.v1.ChallengeSolution - 14, // 23: teleport.join.v1.JoinRequest.iam_init:type_name -> teleport.join.v1.IAMInit - 25, // 24: teleport.join.v1.JoinRequest.giving_up:type_name -> teleport.join.v1.GivingUp - 17, // 25: teleport.join.v1.JoinRequest.ec2_init:type_name -> teleport.join.v1.EC2Init - 7, // 26: teleport.join.v1.JoinRequest.oidc_init:type_name -> teleport.join.v1.OIDCInit - 18, // 27: teleport.join.v1.JoinRequest.oracle_init:type_name -> teleport.join.v1.OracleInit - 21, // 28: teleport.join.v1.JoinRequest.tpm_init:type_name -> teleport.join.v1.TPMInit - 9, // 29: teleport.join.v1.Challenge.bound_keypair_challenge:type_name -> teleport.join.v1.BoundKeypairChallenge - 11, // 30: teleport.join.v1.Challenge.bound_keypair_rotation_request:type_name -> teleport.join.v1.BoundKeypairRotationRequest - 15, // 31: teleport.join.v1.Challenge.iam_challenge:type_name -> teleport.join.v1.IAMChallenge - 19, // 32: teleport.join.v1.Challenge.oracle_challenge:type_name -> teleport.join.v1.OracleChallenge - 22, // 33: teleport.join.v1.Challenge.tpm_encrypted_credential:type_name -> teleport.join.v1.TPMEncryptedCredential - 31, // 34: teleport.join.v1.Result.host_result:type_name -> teleport.join.v1.HostResult - 32, // 35: teleport.join.v1.Result.bot_result:type_name -> teleport.join.v1.BotResult - 30, // 36: teleport.join.v1.HostResult.certificates:type_name -> teleport.join.v1.Certificates - 30, // 37: teleport.join.v1.BotResult.certificates:type_name -> teleport.join.v1.Certificates - 13, // 38: teleport.join.v1.BotResult.bound_keypair_result:type_name -> teleport.join.v1.BoundKeypairResult - 27, // 39: teleport.join.v1.JoinResponse.init:type_name -> teleport.join.v1.ServerInit - 28, // 40: teleport.join.v1.JoinResponse.challenge:type_name -> teleport.join.v1.Challenge - 29, // 41: teleport.join.v1.JoinResponse.result:type_name -> teleport.join.v1.Result - 26, // 42: teleport.join.v1.JoinService.Join:input_type -> teleport.join.v1.JoinRequest - 33, // 43: teleport.join.v1.JoinService.Join:output_type -> teleport.join.v1.JoinResponse - 43, // [43:44] is the sub-list for method output_type - 42, // [42:43] is the sub-list for method input_type - 42, // [42:42] is the sub-list for extension type_name - 42, // [42:42] is the sub-list for extension extendee - 0, // [0:42] is the sub-list for field type_name + 5, // 13: teleport.join.v1.AzureInit.client_params:type_name -> teleport.join.v1.ClientParams + 10, // 14: teleport.join.v1.ChallengeSolution.bound_keypair_challenge_solution:type_name -> teleport.join.v1.BoundKeypairChallengeSolution + 12, // 15: teleport.join.v1.ChallengeSolution.bound_keypair_rotation_response:type_name -> teleport.join.v1.BoundKeypairRotationResponse + 16, // 16: teleport.join.v1.ChallengeSolution.iam_challenge_solution:type_name -> teleport.join.v1.IAMChallengeSolution + 20, // 17: teleport.join.v1.ChallengeSolution.oracle_challenge_solution:type_name -> teleport.join.v1.OracleChallengeSolution + 23, // 18: teleport.join.v1.ChallengeSolution.tpm_solution:type_name -> teleport.join.v1.TPMSolution + 26, // 19: teleport.join.v1.ChallengeSolution.azure_challenge_solution:type_name -> teleport.join.v1.AzureChallengeSolution + 0, // 20: teleport.join.v1.GivingUp.reason:type_name -> teleport.join.v1.GivingUp.Reason + 1, // 21: teleport.join.v1.JoinRequest.client_init:type_name -> teleport.join.v1.ClientInit + 6, // 22: teleport.join.v1.JoinRequest.token_init:type_name -> teleport.join.v1.TokenInit + 8, // 23: teleport.join.v1.JoinRequest.bound_keypair_init:type_name -> teleport.join.v1.BoundKeypairInit + 27, // 24: teleport.join.v1.JoinRequest.solution:type_name -> teleport.join.v1.ChallengeSolution + 14, // 25: teleport.join.v1.JoinRequest.iam_init:type_name -> teleport.join.v1.IAMInit + 28, // 26: teleport.join.v1.JoinRequest.giving_up:type_name -> teleport.join.v1.GivingUp + 17, // 27: teleport.join.v1.JoinRequest.ec2_init:type_name -> teleport.join.v1.EC2Init + 7, // 28: teleport.join.v1.JoinRequest.oidc_init:type_name -> teleport.join.v1.OIDCInit + 18, // 29: teleport.join.v1.JoinRequest.oracle_init:type_name -> teleport.join.v1.OracleInit + 21, // 30: teleport.join.v1.JoinRequest.tpm_init:type_name -> teleport.join.v1.TPMInit + 24, // 31: teleport.join.v1.JoinRequest.azure_init:type_name -> teleport.join.v1.AzureInit + 9, // 32: teleport.join.v1.Challenge.bound_keypair_challenge:type_name -> teleport.join.v1.BoundKeypairChallenge + 11, // 33: teleport.join.v1.Challenge.bound_keypair_rotation_request:type_name -> teleport.join.v1.BoundKeypairRotationRequest + 15, // 34: teleport.join.v1.Challenge.iam_challenge:type_name -> teleport.join.v1.IAMChallenge + 19, // 35: teleport.join.v1.Challenge.oracle_challenge:type_name -> teleport.join.v1.OracleChallenge + 22, // 36: teleport.join.v1.Challenge.tpm_encrypted_credential:type_name -> teleport.join.v1.TPMEncryptedCredential + 25, // 37: teleport.join.v1.Challenge.azure_challenge:type_name -> teleport.join.v1.AzureChallenge + 34, // 38: teleport.join.v1.Result.host_result:type_name -> teleport.join.v1.HostResult + 35, // 39: teleport.join.v1.Result.bot_result:type_name -> teleport.join.v1.BotResult + 33, // 40: teleport.join.v1.HostResult.certificates:type_name -> teleport.join.v1.Certificates + 33, // 41: teleport.join.v1.BotResult.certificates:type_name -> teleport.join.v1.Certificates + 13, // 42: teleport.join.v1.BotResult.bound_keypair_result:type_name -> teleport.join.v1.BoundKeypairResult + 30, // 43: teleport.join.v1.JoinResponse.init:type_name -> teleport.join.v1.ServerInit + 31, // 44: teleport.join.v1.JoinResponse.challenge:type_name -> teleport.join.v1.Challenge + 32, // 45: teleport.join.v1.JoinResponse.result:type_name -> teleport.join.v1.Result + 29, // 46: teleport.join.v1.JoinService.Join:input_type -> teleport.join.v1.JoinRequest + 36, // 47: teleport.join.v1.JoinService.Join:output_type -> teleport.join.v1.JoinResponse + 47, // [47:48] is the sub-list for method output_type + 46, // [46:47] is the sub-list for method input_type + 46, // [46:46] is the sub-list for extension type_name + 46, // [46:46] is the sub-list for extension extendee + 0, // [0:46] is the sub-list for field type_name } func init() { file_teleport_join_v1_joinservice_proto_init() } @@ -2886,14 +3125,15 @@ func file_teleport_join_v1_joinservice_proto_init() { (*TPMInit_EkCert)(nil), (*TPMInit_EkKey)(nil), } - file_teleport_join_v1_joinservice_proto_msgTypes[23].OneofWrappers = []any{ + file_teleport_join_v1_joinservice_proto_msgTypes[26].OneofWrappers = []any{ (*ChallengeSolution_BoundKeypairChallengeSolution)(nil), (*ChallengeSolution_BoundKeypairRotationResponse)(nil), (*ChallengeSolution_IamChallengeSolution)(nil), (*ChallengeSolution_OracleChallengeSolution)(nil), (*ChallengeSolution_TpmSolution)(nil), + (*ChallengeSolution_AzureChallengeSolution)(nil), } - file_teleport_join_v1_joinservice_proto_msgTypes[25].OneofWrappers = []any{ + file_teleport_join_v1_joinservice_proto_msgTypes[28].OneofWrappers = []any{ (*JoinRequest_ClientInit)(nil), (*JoinRequest_TokenInit)(nil), (*JoinRequest_BoundKeypairInit)(nil), @@ -2904,20 +3144,22 @@ func file_teleport_join_v1_joinservice_proto_init() { (*JoinRequest_OidcInit)(nil), (*JoinRequest_OracleInit)(nil), (*JoinRequest_TpmInit)(nil), + (*JoinRequest_AzureInit)(nil), } - file_teleport_join_v1_joinservice_proto_msgTypes[27].OneofWrappers = []any{ + file_teleport_join_v1_joinservice_proto_msgTypes[30].OneofWrappers = []any{ (*Challenge_BoundKeypairChallenge)(nil), (*Challenge_BoundKeypairRotationRequest)(nil), (*Challenge_IamChallenge)(nil), (*Challenge_OracleChallenge)(nil), (*Challenge_TpmEncryptedCredential)(nil), + (*Challenge_AzureChallenge)(nil), } - file_teleport_join_v1_joinservice_proto_msgTypes[28].OneofWrappers = []any{ + file_teleport_join_v1_joinservice_proto_msgTypes[31].OneofWrappers = []any{ (*Result_HostResult)(nil), (*Result_BotResult)(nil), } - file_teleport_join_v1_joinservice_proto_msgTypes[31].OneofWrappers = []any{} - file_teleport_join_v1_joinservice_proto_msgTypes[32].OneofWrappers = []any{ + file_teleport_join_v1_joinservice_proto_msgTypes[34].OneofWrappers = []any{} + file_teleport_join_v1_joinservice_proto_msgTypes[35].OneofWrappers = []any{ (*JoinResponse_Init)(nil), (*JoinResponse_Challenge)(nil), (*JoinResponse_Result)(nil), @@ -2928,7 +3170,7 @@ func file_teleport_join_v1_joinservice_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_teleport_join_v1_joinservice_proto_rawDesc), len(file_teleport_join_v1_joinservice_proto_rawDesc)), NumEnums: 1, - NumMessages: 34, + NumMessages: 37, NumExtensions: 0, NumServices: 1, }, diff --git a/api/proto/teleport/join/v1/joinservice.proto b/api/proto/teleport/join/v1/joinservice.proto index bab3ed7df8429..2bc620a2ac6ac 100644 --- a/api/proto/teleport/join/v1/joinservice.proto +++ b/api/proto/teleport/join/v1/joinservice.proto @@ -349,6 +349,44 @@ message TPMSolution { bytes solution = 1; } +// AzureInit is sent from the client in response to the ServerInit message for +// the Azure join method. +// +// The Azure method join flow is: +// 1. client->server: ClientInit +// 2. client<-server: ServerInit +// 3. client->server: AzureInit +// 4. client<-server: AzureChallenge +// 5. client->server: AzureChallengeSolution +// 6. client<-server: Result +message AzureInit { + // ClientParams holds parameters for the specific type of client trying to join. + ClientParams client_params = 1; +} + +// AzureChallenge is sent from the server in response to the AzureInit message from the client. +// The client is expected to respond with a AzureChallengeSolution. +message AzureChallenge { + // Challenge is a a crypto-random string that should be included by the + // client in the challenge response message. + string challenge = 1; +} + +// AzureChallengeSolution must be sent from the client in response to the +// AzureChallenge message. +message AzureChallengeSolution { + // AttestedData is a signed JSON document from an Azure VM's attested data + // metadata endpoint used to prove the identity of a joining node. It must + // include the challenge string as the nonce. + bytes attested_data = 1; + // Intermediate encodes the intermediate CAs that issued the leaf certificate + // used to sign the attested data document, in x509 DER format. + bytes intermediate = 2; + // AccessToken is a JWT signed by Azure, used to prove the identity of a + // joining node. + string access_token = 3; +} + // ChallengeSolution holds a solution to a challenge issued by the server. message ChallengeSolution { oneof payload { @@ -357,6 +395,7 @@ message ChallengeSolution { IAMChallengeSolution iam_challenge_solution = 3; OracleChallengeSolution oracle_challenge_solution = 4; TPMSolution tpm_solution = 5; + AzureChallengeSolution azure_challenge_solution = 6; } } @@ -396,6 +435,7 @@ message JoinRequest { OIDCInit oidc_init = 8; OracleInit oracle_init = 9; TPMInit tpm_init = 10; + AzureInit azure_init = 11; } } @@ -417,6 +457,7 @@ message Challenge { IAMChallenge iam_challenge = 3; OracleChallenge oracle_challenge = 4; TPMEncryptedCredential tpm_encrypted_credential = 5; + AzureChallenge azure_challenge = 6; } } diff --git a/lib/join/internal/messages/messages.go b/lib/join/internal/messages/messages.go index 26f9392aa2b20..00444ab14be4c 100644 --- a/lib/join/internal/messages/messages.go +++ b/lib/join/internal/messages/messages.go @@ -468,6 +468,51 @@ type TPMSolution struct { Solution []byte } +// AzureInit is sent from the client in response to the ServerInit message for +// the Azure join method. +// +// The Azure method join flow is: +// 1. client->server: ClientInit +// 2. client<-server: ServerInit +// 3. client->server: AzureInit +// 4. client<-server: AzureChallenge +// 5. client->server: AzureChallengeSolution +// 6. client<-server: Result +type AzureInit struct { + embedRequest + + // ClientParams holds parameters for the specific type of client trying to join. + ClientParams ClientParams +} + +// AzureChallenge is sent from the server in response to the AzureInit message +// from the client. The client is expected to respond with a +// AzureChallengeSolution. +type AzureChallenge struct { + embedResponse + + // Challenge is a a crypto-random string that should be included by the + // client in the AzureChallengeSolution message. + Challenge string +} + +// AzureChallenge message. +// AzureChallengeSolution must be sent from the client in response to the +type AzureChallengeSolution struct { + embedRequest + + // AttestedData is a signed JSON document from an Azure VM's attested data + // metadata endpoint used to prove the identity of a joining node. It must + // include the challenge string as the nonce. + AttestedData []byte + // Intermediate encodes the intermediate CAs that issued the leaf certificate + // used to sign the attested data document, in x509 DER format. + Intermediate []byte + // AccessToken is a JWT signed by Azure, used to prove the identity of a + // joining node. + AccessToken string +} + // Response is implemented by all join response messages. type Response interface { isResponse() diff --git a/lib/join/joinv1/messages.go b/lib/join/joinv1/messages.go index b884dbab4a12b..a4649a366f975 100644 --- a/lib/join/joinv1/messages.go +++ b/lib/join/joinv1/messages.go @@ -44,6 +44,8 @@ func requestToMessage(req *joinv1.JoinRequest) (messages.Request, error) { return oracleInitToMessage(msg.OracleInit) case *joinv1.JoinRequest_TpmInit: return tpmInitToMessage(msg.TpmInit) + case *joinv1.JoinRequest_AzureInit: + return azureInitToMessage(msg.AzureInit) case *joinv1.JoinRequest_Solution: return challengeSolutionToMessage(msg.Solution) case *joinv1.JoinRequest_GivingUp: @@ -133,11 +135,22 @@ func requestFromMessage(msg messages.Request) (*joinv1.JoinRequest, error) { TpmInit: tpmInit, }, }, nil + case *messages.AzureInit: + azureInit, err := azureInitFromMessage(typedMsg) + if err != nil { + return nil, trace.Wrap(err) + } + return &joinv1.JoinRequest{ + Payload: &joinv1.JoinRequest_AzureInit{ + AzureInit: azureInit, + }, + }, nil case *messages.BoundKeypairChallengeSolution, *messages.BoundKeypairRotationResponse, *messages.IAMChallengeSolution, *messages.OracleChallengeSolution, - *messages.TPMSolution: + *messages.TPMSolution, + *messages.AzureChallengeSolution: solution, err := challengeSolutionFromMessage(typedMsg) if err != nil { return nil, trace.Wrap(err) @@ -307,6 +320,8 @@ func challengeSolutionToMessage(req *joinv1.ChallengeSolution) (messages.Request return oracleChallengeSolutionToMessage(payload.OracleChallengeSolution), nil case *joinv1.ChallengeSolution_TpmSolution: return tpmSolutionToMessage(payload.TpmSolution), nil + case *joinv1.ChallengeSolution_AzureChallengeSolution: + return azureChallengeSolutionToMessage(payload.AzureChallengeSolution), nil default: return nil, trace.BadParameter("unrecognized challenge solution message type %T", payload) } @@ -344,6 +359,12 @@ func challengeSolutionFromMessage(msg messages.Request) (*joinv1.ChallengeSoluti TpmSolution: tpmSolutionFromMessage(typedMsg), }, }, nil + case *messages.AzureChallengeSolution: + return &joinv1.ChallengeSolution{ + Payload: &joinv1.ChallengeSolution_AzureChallengeSolution{ + AzureChallengeSolution: azureChallengeSolutionFromMessage(typedMsg), + }, + }, nil default: return nil, trace.BadParameter("unrecognized challenge solution message type %T", msg) } @@ -377,7 +398,8 @@ func responseFromMessage(msg messages.Response) (*joinv1.JoinResponse, error) { *messages.BoundKeypairRotationRequest, *messages.IAMChallenge, *messages.OracleChallenge, - *messages.TPMEncryptedCredential: + *messages.TPMEncryptedCredential, + *messages.AzureChallenge: challenge, err := challengeFromMessage(msg) if err != nil { return nil, trace.Wrap(err) @@ -442,6 +464,8 @@ func challengeToMessage(resp *joinv1.Challenge) (messages.Response, error) { return oracleChallengeToMessage(payload.OracleChallenge), nil case *joinv1.Challenge_TpmEncryptedCredential: return tpmEncryptedCredentialToMessage(payload.TpmEncryptedCredential), nil + case *joinv1.Challenge_AzureChallenge: + return azureChallengeToMessage(payload.AzureChallenge), nil default: return nil, trace.BadParameter("unrecognized challenge payload type %T", payload) } @@ -479,6 +503,12 @@ func challengeFromMessage(resp messages.Response) (*joinv1.Challenge, error) { TpmEncryptedCredential: tpmEncryptedCredentialFromMessage(msg), }, }, nil + case *messages.AzureChallenge: + return &joinv1.Challenge{ + Payload: &joinv1.Challenge_AzureChallenge{ + AzureChallenge: azureChallengeFromMessage(msg), + }, + }, nil default: return nil, trace.BadParameter("unrecognized challenge message type %T", msg) } diff --git a/lib/join/joinv1/messages_azure.go b/lib/join/joinv1/messages_azure.go new file mode 100644 index 0000000000000..8fa1872e6b8c4 --- /dev/null +++ b/lib/join/joinv1/messages_azure.go @@ -0,0 +1,72 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package joinv1 + +import ( + "github.com/gravitational/trace" + + joinv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/join/v1" + "github.com/gravitational/teleport/lib/join/internal/messages" +) + +func azureInitToMessage(req *joinv1.AzureInit) (*messages.AzureInit, error) { + clientParams, err := clientParamsToMessage(req.GetClientParams()) + if err != nil { + return nil, trace.Wrap(err) + } + return &messages.AzureInit{ + ClientParams: clientParams, + }, nil +} + +func azureInitFromMessage(msg *messages.AzureInit) (*joinv1.AzureInit, error) { + clientParams, err := clientParamsFromMessage(msg.ClientParams) + if err != nil { + return nil, trace.Wrap(err) + } + return &joinv1.AzureInit{ + ClientParams: clientParams, + }, nil +} + +func azureChallengeToMessage(req *joinv1.AzureChallenge) *messages.AzureChallenge { + return &messages.AzureChallenge{ + Challenge: req.GetChallenge(), + } +} + +func azureChallengeFromMessage(msg *messages.AzureChallenge) *joinv1.AzureChallenge { + return &joinv1.AzureChallenge{ + Challenge: msg.Challenge, + } +} + +func azureChallengeSolutionToMessage(req *joinv1.AzureChallengeSolution) *messages.AzureChallengeSolution { + return &messages.AzureChallengeSolution{ + AttestedData: req.GetAttestedData(), + Intermediate: req.GetIntermediate(), + AccessToken: req.GetAccessToken(), + } +} + +func azureChallengeSolutionFromMessage(msg *messages.AzureChallengeSolution) *joinv1.AzureChallengeSolution { + return &joinv1.AzureChallengeSolution{ + AttestedData: msg.AttestedData, + Intermediate: msg.Intermediate, + AccessToken: msg.AccessToken, + } +} diff --git a/lib/join/joinv1/messages_test.go b/lib/join/joinv1/messages_test.go index 334e025509511..3b0999ab091f4 100644 --- a/lib/join/joinv1/messages_test.go +++ b/lib/join/joinv1/messages_test.go @@ -88,6 +88,12 @@ func TestRequestToMessage(t *testing.T) { Payload: &joinv1.JoinRequest_TpmInit{}, }, }, + { + desc: "empty AzureInit", + req: &joinv1.JoinRequest{ + Payload: &joinv1.JoinRequest_AzureInit{}, + }, + }, { desc: "empty HostParams", req: &joinv1.JoinRequest{ @@ -168,6 +174,16 @@ func TestRequestToMessage(t *testing.T) { }, }, }, + { + desc: "empty AzureSolution", + req: &joinv1.JoinRequest{ + Payload: &joinv1.JoinRequest_Solution{ + Solution: &joinv1.ChallengeSolution{ + Payload: &joinv1.ChallengeSolution_AzureChallengeSolution{}, + }, + }, + }, + }, { desc: "empty GivingUp", req: &joinv1.JoinRequest{ From f1a8385165cc3170007af17a5f3889e90c370569 Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 10 Dec 2025 10:40:33 -0800 Subject: [PATCH 2/2] [v18] Azure method support in new join service Backport #61129 to branch/v18 --- lib/auth/auth.go | 4 + lib/auth/bot_test.go | 92 --- lib/auth/export_test.go | 31 - lib/auth/join/join.go | 18 +- lib/auth/join_azure_legacy.go | 119 ++++ .../join_azure.go => join/azurejoin/azure.go} | 368 ++++++------ lib/{auth => join/azurejoin}/azure_certs.go | 10 +- .../azurejoin}/azure_certs_test.go | 6 +- .../azurejoin}/join_azure_test.go | 546 ++++++++++++------ lib/join/joinclient/join.go | 3 + lib/join/joinclient/join_azure.go | 135 +++++ lib/join/server.go | 4 + lib/join/server_azure.go | 118 ++++ 13 files changed, 964 insertions(+), 490 deletions(-) create mode 100644 lib/auth/join_azure_legacy.go rename lib/{auth/join_azure.go => join/azurejoin/azure.go} (62%) rename lib/{auth => join/azurejoin}/azure_certs.go (98%) rename lib/{auth => join/azurejoin}/azure_certs_test.go (92%) rename lib/{auth => join/azurejoin}/join_azure_test.go (65%) create mode 100644 lib/join/joinclient/join_azure.go create mode 100644 lib/join/server_azure.go diff --git a/lib/auth/auth.go b/lib/auth/auth.go index a8acdc1631dc9..ceda96b47cba0 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -115,6 +115,7 @@ import ( iterstream "github.com/gravitational/teleport/lib/itertools/stream" "github.com/gravitational/teleport/lib/join" "github.com/gravitational/teleport/lib/join/azuredevops" + "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/bitbucket" joinboundkeypair "github.com/gravitational/teleport/lib/join/boundkeypair" "github.com/gravitational/teleport/lib/join/circleci" @@ -1329,6 +1330,9 @@ type Server struct { // override the implementation used in tests. env0IDTokenValidator join.Env0TokenValidator + // azureJoinConfig holds configuration for the Azure join method. + azureJoinConfig *azurejoin.AzureJoinConfig + // loadAllCAs tells tsh to load the host CAs for all clusters when trying to ssh into a node. loadAllCAs bool diff --git a/lib/auth/bot_test.go b/lib/auth/bot_test.go index 5bb8bef0f816b..fbcd010fd723d 100644 --- a/lib/auth/bot_test.go +++ b/lib/auth/bot_test.go @@ -23,10 +23,6 @@ import ( "context" "crypto" "crypto/tls" - "crypto/x509" - "encoding/base64" - "encoding/json" - "encoding/pem" "errors" "fmt" "io" @@ -36,10 +32,8 @@ import ( "text/template" "time" - "github.com/digitorus/pkcs7" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -58,16 +52,13 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authclient" "github.com/gravitational/teleport/lib/auth/authtest" "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/auth/testauthority" - "github.com/gravitational/teleport/lib/cloud/azure" libevents "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/eventstest" - "github.com/gravitational/teleport/lib/fixtures" "github.com/gravitational/teleport/lib/join/iamjoin" "github.com/gravitational/teleport/lib/join/joinclient" "github.com/gravitational/teleport/lib/kube/token" @@ -744,89 +735,6 @@ func TestRegisterBot_RemoteAddr(t *testing.T) { require.NoError(t, err) checkCertLoginIP(t, certs.TLS, remoteAddr) }) - - t.Run("Azure method", func(t *testing.T) { - subID := uuid.NewString() - resourceGroup := "rg" - rsID := vmResourceID(subID, resourceGroup, "test-vm") - vmID := "vmID" - - accessToken, err := makeToken(rsID, "", a.GetClock().Now()) - require.NoError(t, err) - - // add token to auth server - azureTokenName := "azure-test-token" - azureToken, err := types.NewProvisionTokenFromSpec( - azureTokenName, - time.Now().Add(time.Minute), - types.ProvisionTokenSpecV2{ - Roles: []types.SystemRole{types.RoleBot}, - Azure: &types.ProvisionTokenSpecV2Azure{Allow: []*types.ProvisionTokenSpecV2Azure_Rule{{Subscription: subID}}}, - BotName: botName, - JoinMethod: types.JoinMethodAzure, - }) - require.NoError(t, err) - require.NoError(t, a.UpsertToken(ctx, azureToken)) - - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - rsID: { - ID: rsID, - Name: "test-vm", - Subscription: subID, - ResourceGroup: resourceGroup, - VMID: vmID, - }, - }, - } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - subID: vmClient, - }) - - tlsConfig, err := fixtures.LocalTLSConfig() - require.NoError(t, err) - - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) - require.NoError(t, err) - - certs, err := a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - ad := auth.AttestedData{ - Nonce: challenge, - SubscriptionID: subID, - ID: vmID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) - require.NoError(t, err) - - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: azureTokenName, - HostID: "test-node", - Role: types.RoleBot, - PublicSSHKey: sshPubKey, - PublicTLSKey: tlsPubKey, - RemoteAddr: remoteAddr, - }, - AttestedData: signedADBytes, - AccessToken: accessToken, - } - return req, nil - }, auth.WithAzureCerts([]*x509.Certificate{tlsConfig.Certificate}), auth.WithAzureVerifyFunc(mockVerifyToken(nil)), auth.WithAzureVMClientGetter(getVMClient)) - require.NoError(t, err) - checkCertLoginIP(t, certs.TLS, remoteAddr) - }) } func responseFromAWSIdentity(id iamjoin.AWSIdentity) string { diff --git a/lib/auth/export_test.go b/lib/auth/export_test.go index a2eb46bc3a737..3d0dee5ce4e09 100644 --- a/lib/auth/export_test.go +++ b/lib/auth/export_test.go @@ -70,8 +70,6 @@ const ( MaxUserAgentLen = maxUserAgentLen ForwardedTag = forwardedTag - - AzureAccessTokenAudience = azureAccessTokenAudience ) var ( @@ -307,10 +305,6 @@ func TrimUserAgent(userAgent string) string { return trimUserAgent(userAgent) } -func IsAllowedDomain(cn string, domains []string) bool { - return isAllowedDomain(cn, domains) -} - func GetSnowflakeJWTParams(ctx context.Context, accountName, userName string, publicKey []byte) (string, string) { return getSnowflakeJWTParams(ctx, accountName, userName, publicKey) } @@ -336,31 +330,6 @@ func CheckHeaders(headers http.Header, challenge string, clock clockwork.Clock) } type GitHubManager = githubManager -type AttestedData = attestedData -type SignedAttestedData = signedAttestedData -type AzureRegisterOption = azureRegisterOption -type AzureRegisterConfig = azureRegisterConfig -type AzureVMClientGetter = vmClientGetter -type AzureVerifyTokenFunc = azureVerifyTokenFunc -type AccessTokenClaims = accessTokenClaims - -func WithAzureCerts(certs []*x509.Certificate) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.certificateAuthorities = certs - } -} - -func WithAzureVerifyFunc(verify azureVerifyTokenFunc) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.verify = verify - } -} - -func WithAzureVMClientGetter(getVMClient vmClientGetter) AzureRegisterOption { - return func(cfg *AzureRegisterConfig) { - cfg.getVMClient = getVMClient - } -} func (s *TLSServer) GRPCServer() *GRPCServer { return s.grpcServer diff --git a/lib/auth/join/join.go b/lib/auth/join/join.go index bd286b53928cb..d5fae28d72cf1 100644 --- a/lib/auth/join/join.go +++ b/lib/auth/join/join.go @@ -78,6 +78,19 @@ type AzureParams struct { // ClientID is the client ID of the managed identity for Teleport to assume // when authenticating a node. ClientID string + // IMDSClient overrides the client used to fetch data from Azure IMDS. + IMDSClient AzureIMDSClient + // IssuerHTTPClient, if set, overrides the default HTTP client used to + // fetch the intermediate CA which issued the attested data document + // signing certificate. Only used when joining via the new join service. + IssuerHTTPClient utils.HTTPDoClient +} + +// AzureIMDSClient is a client to Azure's IMDS. +type AzureIMDSClient interface { + IsAvailable(context.Context) bool + GetAttestedData(ctx context.Context, nonce string) ([]byte, error) + GetAccessToken(ctx context.Context, clientID string) (string, error) } // GitlabParams is the parameters specific to the gitlab join method. @@ -875,7 +888,10 @@ func registerUsingAzureMethod( ctx context.Context, client joinServiceClient, token string, hostKeys *newHostKeys, params RegisterParams, ) (*proto.Certs, error) { certs, err := client.RegisterUsingAzureMethod(ctx, func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - imds := azure.NewInstanceMetadataClient() + imds := params.AzureParams.IMDSClient + if imds == nil { + imds = azure.NewInstanceMetadataClient() + } if !imds.IsAvailable(ctx) { return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") } diff --git a/lib/auth/join_azure_legacy.go b/lib/auth/join_azure_legacy.go new file mode 100644 index 0000000000000..9032f084e6aeb --- /dev/null +++ b/lib/auth/join_azure_legacy.go @@ -0,0 +1,119 @@ +/* + * Teleport + * Copyright (C) 2023 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package auth + +import ( + "context" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/legacyjoin" +) + +// RegisterUsingAzureMethod registers the caller using the Azure join method +// and returns signed certs to join the cluster. +// +// The caller must provide a ChallengeResponseFunc which returns a +// *proto.RegisterUsingAzureMethodRequest with a signed attested data document +// including the challenge as a nonce. +// +// TODO(nklaassen): DELETE IN 20 when removing the legacy join service. +func (a *Server) RegisterUsingAzureMethod( + ctx context.Context, + challengeResponse client.RegisterAzureChallengeResponseFunc, +) (certs *proto.Certs, err error) { + var provisionToken types.ProvisionToken + var joinRequest *types.RegisterUsingTokenRequest + defer func() { + // Emit a log message and audit event on join failure. + if err != nil { + a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest) + } + }() + + if legacyjoin.Disabled() { + return nil, trace.Wrap(legacyjoin.ErrDisabled) + } + + challenge, err := azurejoin.GenerateAzureChallenge() + if err != nil { + return nil, trace.Wrap(err) + } + req, err := challengeResponse(challenge) + if err != nil { + return nil, trace.Wrap(err) + } + joinRequest = req.RegisterUsingTokenRequest + + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest) + if err != nil { + return nil, trace.Wrap(err) + } + if provisionToken.GetJoinMethod() != types.JoinMethodAzure { + return nil, trace.AccessDenied("this token does not support the Azure join method") + } + + ptv2, ok := provisionToken.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.Wrap(err, "Azure join method only supports ProvisionTokenV2, got %T", provisionToken) + } + + joinAttrs, err := azurejoin.CheckAzureRequest(ctx, azurejoin.CheckAzureRequestParams{ + AzureJoinConfig: a.GetAzureJoinConfig(), + Token: ptv2, + Challenge: challenge, + AttestedData: req.AttestedData, + AccessToken: req.AccessToken, + Logger: a.logger, + Clock: a.GetClock(), + }) + if err != nil { + return nil, trace.Wrap(err, "checking Azure challenge response") + } + + if req.RegisterUsingTokenRequest.Role == types.RoleBot { + params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{ + Azure: joinAttrs, + }) + certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params) + return certs, trace.Wrap(err) + } + params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/) + certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params) + return certs, trace.Wrap(err) +} + +// GetAzureJoinConfig gets configuration options for azure joining. +func (a *Server) GetAzureJoinConfig() *azurejoin.AzureJoinConfig { + return a.azureJoinConfig +} + +// SetAzureJoinConfig sets configuration options for azure joining. +func (a *Server) SetAzureJoinConfig(c *azurejoin.AzureJoinConfig) { + a.azureJoinConfig = c +} diff --git a/lib/auth/join_azure.go b/lib/join/azurejoin/azure.go similarity index 62% rename from lib/auth/join_azure.go rename to lib/join/azurejoin/azure.go index ff54227baf252..7d2f5e7957901 100644 --- a/lib/auth/join_azure.go +++ b/lib/join/azurejoin/azure.go @@ -1,29 +1,26 @@ -/* - * Teleport - * Copyright (C) 2023 Gravitational, Inc. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package auth +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package azurejoin import ( "cmp" "context" "crypto/x509" "encoding/base64" - "encoding/pem" "log/slog" "net/url" "slices" @@ -37,21 +34,20 @@ import ( "github.com/digitorus/pkcs7" "github.com/go-jose/go-jose/v3/jwt" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/zitadel/oidc/v3/pkg/oidc" - "github.com/gravitational/teleport/api/client" - "github.com/gravitational/teleport/api/client/proto" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/cloud/azure" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/join/joinutils" - "github.com/gravitational/teleport/lib/join/legacyjoin" liboidc "github.com/gravitational/teleport/lib/oidc" "github.com/gravitational/teleport/lib/utils" ) const ( - azureAccessTokenAudience = "https://management.azure.com/" + AzureAccessTokenAudience = "https://management.azure.com/" // azureUserAgent specifies the Azure User-Agent identification for telemetry. azureUserAgent = "teleport" @@ -64,7 +60,8 @@ const ( // Structs for unmarshaling attested data. Schema can be found at // https://learn.microsoft.com/en-us/azure/virtual-machines/linux/instance-metadata-service?tabs=linux#response-2 -type signedAttestedData struct { +// SignedAttestedData models the response from the attested data IMDS endpoint. +type SignedAttestedData struct { Encoding string `json:"encoding"` Signature string `json:"signature"` } @@ -80,7 +77,8 @@ type timestamp struct { ExpiresOn string `json:"expiresOn"` } -type attestedData struct { +// AttestedData models the decoded data returned from the attested data IMDS endpoint. +type AttestedData struct { LicenseType string `json:"licenseType"` Nonce string `json:"nonce"` Plan plan `json:"plan"` @@ -90,7 +88,8 @@ type attestedData struct { SKU string `json:"sku"` } -type accessTokenClaims struct { +// AccessTokenClaims models the claims in an Azure access token. +type AccessTokenClaims struct { oidc.TokenClaims TenantID string `json:"tid"` Version string `json:"ver"` @@ -111,7 +110,7 @@ type accessTokenClaims struct { AzureResourceID string `json:"xms_az_rid"` } -func (c *accessTokenClaims) AsJWTClaims() jwt.Claims { +func (c *AccessTokenClaims) asJWTClaims() jwt.Claims { return jwt.Claims{ Issuer: c.Issuer, Subject: c.Subject, @@ -123,24 +122,36 @@ func (c *accessTokenClaims) AsJWTClaims() jwt.Claims { } } -type azureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) - -type vmClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) - -type azureRegisterConfig struct { - certificateAuthorities []*x509.Certificate - verify azureVerifyTokenFunc - getVMClient vmClientGetter +// AzureVerifyTokenFunc is a function type that verifies an azure VM token. +type AzureVerifyTokenFunc func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error) + +// VMClientGetter is a function type that returns an Azure VM client for a +// given subscription authenticated with a given static token credential. +type VMClientGetter func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) + +// AzureJoinConfig holds configurable options for Azure joining. +type AzureJoinConfig struct { + // CertificateAuthorities, if set, overrides the root certificate + // authorities used to verify VM attested data. + CertificateAuthorities []*x509.Certificate + // Verify, if set, overrides the function used to verify azure VM tokens. + Verify AzureVerifyTokenFunc + // GetVMClient, if set, overrides the function used to get Azure VM clients. + GetVMClient VMClientGetter + // IssuerHTTPClient, if set, overrides the default HTTP client used to + // fetch the intermediate CA which issued the attested data document + // signing certificate. + IssuerHTTPClient utils.HTTPDoClient } -func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc { - return func(ctx context.Context, rawIDToken string) (*accessTokenClaims, error) { +func azureVerifyFuncFromOIDCVerifier(clientID string) AzureVerifyTokenFunc { + return func(ctx context.Context, rawIDToken string) (*AccessTokenClaims, error) { token, err := jwt.ParseSigned(rawIDToken) if err != nil { return nil, trace.Wrap(err) } // Need to get the tenant ID before we verify so we can construct the issuer URL. - var unverifiedClaims accessTokenClaims + var unverifiedClaims AccessTokenClaims if err := token.UnsafeClaimsWithoutVerification(&unverifiedClaims); err != nil { return nil, trace.Wrap(err) } @@ -148,24 +159,24 @@ func azureVerifyFuncFromOIDCVerifier(clientID string) azureVerifyTokenFunc { if err != nil { return nil, trace.Wrap(err) } - return liboidc.ValidateToken[*accessTokenClaims](ctx, issuer, clientID, rawIDToken) + return liboidc.ValidateToken[*AccessTokenClaims](ctx, issuer, clientID, rawIDToken) } } -func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { - if cfg.verify == nil { - cfg.verify = azureVerifyFuncFromOIDCVerifier(azureAccessTokenAudience) +func (cfg *AzureJoinConfig) checkAndSetDefaults() error { + if cfg.Verify == nil { + cfg.Verify = azureVerifyFuncFromOIDCVerifier(AzureAccessTokenAudience) } - if cfg.certificateAuthorities == nil { + if cfg.CertificateAuthorities == nil { certs, err := getAzureRootCerts() if err != nil { return trace.Wrap(err) } - cfg.certificateAuthorities = certs + cfg.CertificateAuthorities = certs } - if cfg.getVMClient == nil { - cfg.getVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { + if cfg.GetVMClient == nil { + cfg.GetVMClient = func(subscriptionID string, token *azure.StaticCredential) (azure.VirtualMachinesClient, error) { // The User-Agent is added for debugging purposes. It helps identify // and isolate teleport traffic. opts := &armpolicy.ClientOptions{ @@ -179,57 +190,58 @@ func (cfg *azureRegisterConfig) CheckAndSetDefaults(ctx context.Context) error { return client, trace.Wrap(err) } } + if cfg.IssuerHTTPClient == nil { + httpClient, err := defaults.HTTPClient() + if err != nil { + return trace.Wrap(err) + } + cfg.IssuerHTTPClient = httpClient + } return nil } -type azureRegisterOption func(cfg *azureRegisterConfig) - // parseAndVeryAttestedData verifies that an attested data document was signed // by Azure. If verification is successful, it returns the ID of the VM that // produced the document. -func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge string, certs []*x509.Certificate) (subscriptionID, vmID string, err error) { - var signedAD signedAttestedData - if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { - return "", "", trace.Wrap(err) - } - if signedAD.Encoding != "pkcs7" { - return "", "", trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) - } - - sigPEM := "-----BEGIN PKCS7-----\n" + signedAD.Signature + "\n-----END PKCS7-----" - sigBER, _ := pem.Decode([]byte(sigPEM)) - if sigBER == nil { - return "", "", trace.AccessDenied("unable to decode attested data document") - } - - p7, err := pkcs7.Parse(sigBER.Bytes) +func parseAndVerifyAttestedData( + ctx context.Context, + cfg *AzureJoinConfig, + adBytes []byte, + intermediates []byte, + challenge string, +) (subscriptionID, vmID string, err error) { + ad, p7, err := ParseAttestedData(adBytes) if err != nil { return "", "", trace.Wrap(err) } - var ad attestedData - if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { - return "", "", trace.Wrap(err) - } if ad.Nonce != challenge { return "", "", trace.AccessDenied("challenge is missing or does not match") } - if len(p7.Certificates) == 0 { return "", "", trace.AccessDenied("no certificates for signature") } fixAzureSigningAlgorithm(p7) - // Azure only sends the leaf cert, so we have to fetch the intermediate. - intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0]) - if err != nil { - return "", "", trace.Wrap(err) - } - if intermediate != nil { - p7.Certificates = append(p7.Certificates, intermediate) + if len(intermediates) > 0 { + // Client explicitly sent intermediate CAs, included them. + intermediates, err := x509.ParseCertificates(intermediates) + if err != nil { + return "", "", trace.Wrap(err, "parsing intermediate certificates sent by client") + } + p7.Certificates = append(p7.Certificates, intermediates...) + } else { + // Client did not send intermediates, fetch them from Azure. + intermediate, err := getAzureIssuerCert(ctx, p7.Certificates[0], cfg.IssuerHTTPClient) + if err != nil { + return "", "", trace.Wrap(err) + } + if intermediate != nil { + p7.Certificates = append(p7.Certificates, intermediate) + } } pool := x509.NewCertPool() - for _, cert := range certs { + for _, cert := range cfg.CertificateAuthorities { pool.AddCert(cert) } @@ -240,18 +252,45 @@ func parseAndVerifyAttestedData(ctx context.Context, adBytes []byte, challenge s return ad.SubscriptionID, ad.ID, nil } +// ParseAttestedData returns the parsed VM attested data and a PKCS7 structure +// which can be used to verify the signature. +func ParseAttestedData(adBytes []byte) (*AttestedData, *pkcs7.PKCS7, error) { + var signedAD SignedAttestedData + if err := utils.FastUnmarshal(adBytes, &signedAD); err != nil { + return nil, nil, trace.Wrap(err) + } + if signedAD.Encoding != "pkcs7" { + return nil, nil, trace.AccessDenied("unsupported signature type: %v", signedAD.Encoding) + } + + sigDER, err := base64.StdEncoding.DecodeString(signedAD.Signature) + if err != nil { + return nil, nil, trace.Wrap(err, "decoding attested data document from base64") + } + + p7, err := pkcs7.Parse(sigDER) + if err != nil { + return nil, nil, trace.Wrap(err) + } + var ad AttestedData + if err := utils.FastUnmarshal(p7.Content, &ad); err != nil { + return nil, nil, trace.Wrap(err) + } + return &ad, p7, nil +} + // verifyVMIdentity verifies that the provided access token came from the // correct Azure VM. Returns the Azure join attributes func verifyVMIdentity( ctx context.Context, - cfg *azureRegisterConfig, + cfg *AzureJoinConfig, accessToken, subscriptionID, vmID string, requestStart time.Time, logger *slog.Logger, ) (joinAttrs *workloadidentityv1pb.JoinAttrsAzure, err error) { - tokenClaims, err := cfg.verify(ctx, accessToken) + tokenClaims, err := cfg.Verify(ctx, accessToken) if err != nil { return nil, trace.Wrap(err) } @@ -270,11 +309,11 @@ func verifyVMIdentity( expectedClaims := jwt.Expected{ Issuer: expectedIssuer, - Audience: jwt.Audience{azureAccessTokenAudience}, + Audience: jwt.Audience{AzureAccessTokenAudience}, Time: requestStart, } - if err := tokenClaims.AsJWTClaims().Validate(expectedClaims); err != nil { + if err := tokenClaims.asJWTClaims().Validate(expectedClaims); err != nil { return nil, trace.Wrap(err) } @@ -299,7 +338,7 @@ func verifyVMIdentity( Token: accessToken, ExpiresOn: tokenClaims.GetExpiration(), }) - vmClient, err := cfg.getVMClient(subscriptionID, tokenCredential) + vmClient, err := cfg.GetVMClient(subscriptionID, tokenCredential) if err != nil { return nil, trace.Wrap(err) } @@ -339,7 +378,7 @@ func verifyVMIdentity( } // claimsToIdentifiers returns the vm identifiers from the provided claims. -func claimsToIdentifiers(tokenClaims *accessTokenClaims) (subscriptionID, resourceGroupID string, err error) { +func claimsToIdentifiers(tokenClaims *AccessTokenClaims) (subscriptionID, resourceGroupID string, err error) { // xms_az_rid claim is omitted when the VM is assigned a System-Assigned Identity. // The xms_mirid claim should be used instead. rid := cmp.Or(tokenClaims.AzureResourceID, tokenClaims.ManangedIdentityResourceID) @@ -369,6 +408,7 @@ func checkAzureAllowRules(vmID string, attrs *workloadidentityv1pb.JoinAttrsAzur } return trace.AccessDenied("instance %v did not match any allow rules in token %v", vmID, token.GetName()) } + func azureResourceGroupIsAllowed(allowedResourceGroups []string, vmResourceGroup string) bool { if len(allowedResourceGroups) == 0 { return true @@ -395,126 +435,82 @@ func azureJoinToAttrs(subscriptionID, resourceGroupID string) *workloadidentityv } } -func (a *Server) checkAzureRequest( - ctx context.Context, - challenge string, - req *proto.RegisterUsingAzureMethodRequest, - cfg *azureRegisterConfig, -) (*workloadidentityv1pb.JoinAttrsAzure, error) { - requestStart := a.clock.Now() - tokenName := req.RegisterUsingTokenRequest.Token - provisionToken, err := a.GetToken(ctx, tokenName) - if err != nil { - return nil, trace.Wrap(err) - } - if provisionToken.GetJoinMethod() != types.JoinMethodAzure { - return nil, trace.AccessDenied("this token does not support the Azure join method") - } - token, ok := provisionToken.(*types.ProvisionTokenV2) - if !ok { - return nil, trace.BadParameter("azure join method only supports ProvisionTokenV2, '%T' was provided", provisionToken) - } - - subID, vmID, err := parseAndVerifyAttestedData(ctx, req.AttestedData, challenge, cfg.certificateAuthorities) - if err != nil { - return nil, trace.Wrap(err) - } - - attrs, err := verifyVMIdentity(ctx, cfg, req.AccessToken, subID, vmID, requestStart, a.logger) - if err != nil { - return nil, trace.Wrap(err) - } - if err := checkAzureAllowRules(vmID, attrs, token); err != nil { - return attrs, trace.Wrap(err) - } - - return attrs, nil +// CheckAzureRequestParams holds all parameters for [CheckAzureRequest]. +type CheckAzureRequestParams struct { + // AzureJoinConfig holds configurable options for Azure joining. + AzureJoinConfig *AzureJoinConfig + // Token is the token used for the incoming request. + Token *types.ProvisionTokenV2 + // Challenge is the challenge that was issued. + Challenge string + // AttestedData is the Azure attested data that was returned by the joining + // client. It must include the challenge as a nonce. + AttestedData []byte + // Intermediate encodes the intermediate CAs that issued the leaf certificate + // used to sign the attested data document, in x509 DER format. + Intermediate []byte + // AccessToken is the Azure access token that was returned by the joining client + AccessToken string + // Logger will be used for logging. + Logger *slog.Logger + // Clock overrides the system time. + Clock clockwork.Clock } -func generateAzureChallenge() (string, error) { - challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24) - return challenge, trace.Wrap(err) +func (p *CheckAzureRequestParams) checkAndSetDefaults() error { + switch { + case p.AzureJoinConfig == nil: + p.AzureJoinConfig = &AzureJoinConfig{} + case p.Token == nil: + return trace.BadParameter("Token is required") + case len(p.Challenge) == 0: + return trace.BadParameter("Challenge is required") + case len(p.AttestedData) == 0: + return trace.BadParameter("AttestedData is required") + case len(p.AccessToken) == 0: + return trace.BadParameter("AccessToken is required") + case p.Logger == nil: + return trace.BadParameter("Logger is required") + case p.Clock == nil: + p.Clock = clockwork.NewRealClock() + } + return trace.Wrap(p.AzureJoinConfig.checkAndSetDefaults()) } -// RegisterUsingAzureMethodWithOpts registers the caller using the Azure join method -// and returns signed certs to join the cluster. -// -// The caller must provide a ChallengeResponseFunc which returns a -// *proto.RegisterUsingAzureMethodRequest with a signed attested data document -// including the challenge as a nonce. -func (a *Server) RegisterUsingAzureMethodWithOpts( - ctx context.Context, - challengeResponse client.RegisterAzureChallengeResponseFunc, - opts ...azureRegisterOption, -) (certs *proto.Certs, err error) { - var provisionToken types.ProvisionToken - var joinRequest *types.RegisterUsingTokenRequest - defer func() { - // Emit a log message and audit event on join failure. - if err != nil { - a.handleJoinFailure(ctx, err, provisionToken, nil, joinRequest) - } - }() - - if legacyjoin.Disabled() { - return nil, trace.Wrap(legacyjoin.ErrDisabled) - } - - cfg := &azureRegisterConfig{} - for _, opt := range opts { - opt(cfg) - } - if err := cfg.CheckAndSetDefaults(ctx); err != nil { +// CheckAzureRequest checks an azure join request by verifying the VMs claims +// and checking that they match an allow rule from the join token. +func CheckAzureRequest(ctx context.Context, params CheckAzureRequestParams) (*workloadidentityv1pb.JoinAttrsAzure, error) { + if err := params.checkAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } + requestStart := params.Clock.Now() - challenge, err := generateAzureChallenge() - if err != nil { - return nil, trace.Wrap(err) - } - req, err := challengeResponse(challenge) + subID, vmID, err := parseAndVerifyAttestedData( + ctx, + params.AzureJoinConfig, + params.AttestedData, + params.Intermediate, + params.Challenge, + ) if err != nil { return nil, trace.Wrap(err) } - joinRequest = req.RegisterUsingTokenRequest - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - - provisionToken, err = a.checkTokenJoinRequestCommon(ctx, req.RegisterUsingTokenRequest) + attrs, err := verifyVMIdentity(ctx, params.AzureJoinConfig, params.AccessToken, subID, vmID, requestStart, params.Logger) if err != nil { return nil, trace.Wrap(err) } - - joinAttrs, err := a.checkAzureRequest(ctx, challenge, req, cfg) - if err != nil { - return nil, trace.Wrap(err) + if err := checkAzureAllowRules(vmID, attrs, params.Token); err != nil { + return attrs, trace.Wrap(err) } - if req.RegisterUsingTokenRequest.Role == types.RoleBot { - params := makeBotCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/, &workloadidentityv1pb.JoinAttrs{ - Azure: joinAttrs, - }) - certs, _, err := a.GenerateBotCertsForJoin(ctx, provisionToken, params) - return certs, trace.Wrap(err) - } - params := makeHostCertsParams(req.RegisterUsingTokenRequest, nil /*rawClaims*/) - certs, err = a.GenerateHostCertsForJoin(ctx, provisionToken, params) - return certs, trace.Wrap(err) + return attrs, nil } -// RegisterUsingAzureMethod registers the caller using the Azure join method -// and returns signed certs to join the cluster. -// -// The caller must provide a ChallengeResponseFunc which returns a -// *proto.RegisterUsingAzureMethodRequest with a signed attested data document -// including the challenge as a nonce. -func (a *Server) RegisterUsingAzureMethod( - ctx context.Context, - challengeResponse client.RegisterAzureChallengeResponseFunc, -) (certs *proto.Certs, err error) { - return a.RegisterUsingAzureMethodWithOpts(ctx, challengeResponse) +// GenerateAzureChallenge generates a challenge for the Azure join method. +func GenerateAzureChallenge() (string, error) { + challenge, err := joinutils.GenerateChallenge(base64.RawURLEncoding, 24) + return challenge, trace.Wrap(err) } // fixAzureSigningAlgorithm fixes a mismatch between the object IDs of the diff --git a/lib/auth/azure_certs.go b/lib/join/azurejoin/azure_certs.go similarity index 98% rename from lib/auth/azure_certs.go rename to lib/join/azurejoin/azure_certs.go index cea969aadc87b..431e7bcecb35b 100644 --- a/lib/auth/azure_certs.go +++ b/lib/join/azurejoin/azure_certs.go @@ -16,7 +16,7 @@ * along with this program. If not, see . */ -package auth +package azurejoin import ( "context" @@ -28,7 +28,6 @@ import ( "github.com/gravitational/trace" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/utils" ) @@ -49,16 +48,11 @@ func isAllowedDomain(cn string, domains []string) bool { } // getAzureIssuerCert fetches a x509 certificate's issuing certificate. -func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate) (*x509.Certificate, error) { +func getAzureIssuerCert(ctx context.Context, cert *x509.Certificate, httpClient utils.HTTPDoClient) (*x509.Certificate, error) { if len(cert.IssuingCertificateURL) == 0 { return nil, nil } - httpClient, err := defaults.HTTPClient() - if err != nil { - return nil, trace.Wrap(err) - } - // Azure sends only one issuing cert. issuerURL := cert.IssuingCertificateURL[0] commonName := cert.Subject.CommonName diff --git a/lib/auth/azure_certs_test.go b/lib/join/azurejoin/azure_certs_test.go similarity index 92% rename from lib/auth/azure_certs_test.go rename to lib/join/azurejoin/azure_certs_test.go index abeb857f0edb7..7ebb1acccf004 100644 --- a/lib/auth/azure_certs_test.go +++ b/lib/join/azurejoin/azure_certs_test.go @@ -16,14 +16,12 @@ * along with this program. If not, see . */ -package auth_test +package azurejoin import ( "testing" "github.com/stretchr/testify/require" - - "github.com/gravitational/teleport/lib/auth" ) func TestIsAllowedDomain(t *testing.T) { @@ -71,7 +69,7 @@ func TestIsAllowedDomain(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - tc.assert(t, auth.IsAllowedDomain(tc.url, allowedDomains)) + tc.assert(t, isAllowedDomain(tc.url, allowedDomains)) }) } } diff --git a/lib/auth/join_azure_test.go b/lib/join/azurejoin/join_azure_test.go similarity index 65% rename from lib/auth/join_azure_test.go rename to lib/join/azurejoin/join_azure_test.go index 50e5bbf77665c..9752766d8d769 100644 --- a/lib/auth/join_azure_test.go +++ b/lib/join/azurejoin/join_azure_test.go @@ -16,15 +16,22 @@ * along with this program. If not, see . */ -package auth_test +package azurejoin_test import ( + "bytes" "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" "crypto/x509" + "crypto/x509/pkix" "encoding/base64" "encoding/json" - "encoding/pem" "fmt" + "io" + "net/http" "testing" "time" @@ -36,13 +43,16 @@ import ( "github.com/stretchr/testify/require" "github.com/zitadel/oidc/v3/pkg/oidc" - "github.com/gravitational/teleport/api/client/proto" + headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" + machineidv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/machineid/v1" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" "github.com/gravitational/teleport/lib/auth/authtest" - "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/auth/machineid/machineidv1" + "github.com/gravitational/teleport/lib/auth/state" "github.com/gravitational/teleport/lib/cloud/azure" - "github.com/gravitational/teleport/lib/fixtures" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/joinclient" + "github.com/gravitational/teleport/lib/tlsca" ) type mockAzureVMClient struct { @@ -67,7 +77,7 @@ func (m *mockAzureVMClient) GetByVMID(_ context.Context, vmID string) (*azure.Vi return nil, trace.NotFound("no vm with id %q", vmID) } -func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClientGetter { +func makeVMClientGetter(clients map[string]*mockAzureVMClient) azurejoin.VMClientGetter { return func(subscriptionID string, _ *azure.StaticCredential) (azure.VirtualMachinesClient, error) { if client, ok := clients[subscriptionID]; ok { return client, nil @@ -76,18 +86,6 @@ func makeVMClientGetter(clients map[string]*mockAzureVMClient) auth.AzureVMClien } } -type azureChallengeResponseConfig struct { - Challenge string -} - -type azureChallengeResponseOption func(*azureChallengeResponseConfig) - -func withChallengeAzure(challenge string) azureChallengeResponseOption { - return func(cfg *azureChallengeResponseConfig) { - cfg.Challenge = challenge - } -} - func vmssResourceID(subscription, resourceGroup, name string) string { return resourceID("Microsoft.Compute/virtualMachineScaleSets", subscription, resourceGroup, name) } @@ -107,8 +105,8 @@ func resourceID(resourceType, subscription, resourceGroup, name string) string { ) } -func mockVerifyToken(err error) auth.AzureVerifyTokenFunc { - return func(_ context.Context, rawToken string) (*auth.AccessTokenClaims, error) { +func mockVerifyToken(err error) azurejoin.AzureVerifyTokenFunc { + return func(_ context.Context, rawToken string) (*azurejoin.AccessTokenClaims, error) { if err != nil { return nil, err } @@ -116,7 +114,7 @@ func mockVerifyToken(err error) auth.AzureVerifyTokenFunc { if err != nil { return nil, trace.Wrap(err) } - var claims auth.AccessTokenClaims + var claims azurejoin.AccessTokenClaims if err := tok.UnsafeClaimsWithoutVerification(&claims); err != nil { return nil, trace.Wrap(err) } @@ -132,10 +130,10 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time if err != nil { return "", trace.Wrap(err) } - claims := auth.AccessTokenClaims{ + claims := azurejoin.AccessTokenClaims{ TokenClaims: oidc.TokenClaims{ Issuer: "https://sts.windows.net/test-tenant-id/", - Audience: []string{auth.AzureAccessTokenAudience}, + Audience: []string{azurejoin.AzureAccessTokenAudience}, Subject: "test", IssuedAt: oidc.FromTime(issueTime), NotBefore: oidc.FromTime(issueTime), @@ -154,28 +152,29 @@ func makeToken(managedIdentityResourceID, azureResourceID string, issueTime time return raw, nil } -func TestAuth_RegisterUsingAzureMethod(t *testing.T) { +func TestJoinAzure(t *testing.T) { t.Parallel() + ctx := t.Context() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - p, err := newTestPack(ctx, t.TempDir()) - require.NoError(t, err) - a := p.a - - sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() - require.NoError(t, err) - - tlsConfig, err := fixtures.LocalTLSConfig() + server, err := authtest.NewTestServer(authtest.ServerConfig{ + Auth: authtest.AuthServerConfig{ + Dir: t.TempDir(), + }, + }) require.NoError(t, err) + a := server.Auth() - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) - tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey) + caChain := newFakeAzureCAChain(t) + httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER) + instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) + instanceCert := caChain.issueLeafCert(t, + instanceKey.Public(), + "instance.metadata.azure.com", + "http://www.microsoft.com/pkiops/certs/testcert.crt") isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) @@ -200,10 +199,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 - challengeResponseOptions []azureChallengeResponseOption + overrideReturnedChallenge string challengeResponseErr error certs []*x509.Certificate - verify auth.AzureVerifyTokenFunc + verify azurejoin.AzureVerifyTokenFunc assertError require.ErrorAssertionFunc }{ { @@ -224,7 +223,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -245,7 +244,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -265,7 +264,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -285,7 +284,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, challengeResponseErr: trace.BadParameter("test error"), assertError: isBadParameter, }, @@ -306,7 +305,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -327,7 +326,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -346,12 +345,10 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { }, JoinMethod: types.JoinMethodAzure, }, - challengeResponseOptions: []azureChallengeResponseOption{ - withChallengeAzure("wrong-challenge"), - }, - verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, - assertError: isAccessDenied, + overrideReturnedChallenge: "wrong-challenge", + verify: mockVerifyToken(nil), + certs: []*x509.Certificate{caChain.rootCert}, + assertError: isAccessDenied, }, { name: "invalid signature", @@ -391,7 +388,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -412,7 +409,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -433,7 +430,7 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -454,13 +451,35 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { + vmClient := &mockAzureVMClient{ + vms: map[string]*azure.VirtualMachine{ + defaultVMResourceID: { + ID: defaultVMResourceID, + Name: defaultVMName, + Subscription: defaultSubscription, + ResourceGroup: defaultResourceGroup, + VMID: defaultVMID, + }, + }, + } + getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ + defaultSubscription: vmClient, + }) + + a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{ + CertificateAuthorities: tc.certs, + Verify: tc.verify, + GetVMClient: getVMClient, + IssuerHTTPClient: httpClient, + }) + token, err := types.NewProvisionTokenFromSpec( "test-token", time.Now().Add(time.Minute), @@ -479,88 +498,76 @@ func TestAuth_RegisterUsingAzureMethod(t *testing.T) { accessToken, err := makeToken(mirID, "", a.GetClock().Now()) require.NoError(t, err) - vmClient := &mockAzureVMClient{ - vms: map[string]*azure.VirtualMachine{ - defaultVMResourceID: { - ID: defaultVMResourceID, - Name: defaultVMName, - Subscription: defaultSubscription, - ResourceGroup: defaultResourceGroup, - VMID: defaultVMID, - }, - }, + imdsClient := &fakeIMDSClient{ + accessToken: accessToken, + accessTokenErr: tc.challengeResponseErr, + overrideChallenge: tc.overrideReturnedChallenge, + signingCert: instanceCert, + signingKey: instanceKey, + subscription: tc.tokenSubscription, + vmID: tc.tokenVMID, } - getVMClient := makeVMClientGetter(map[string]*mockAzureVMClient{ - defaultSubscription: vmClient, - }) - - _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - cfg := &azureChallengeResponseConfig{Challenge: challenge} - for _, opt := range tc.challengeResponseOptions { - opt(cfg) - } - - ad := auth.AttestedData{ - Nonce: cfg.Challenge, - SubscriptionID: tc.tokenSubscription, - ID: tc.tokenVMID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) - require.NoError(t, err) - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: tc.requestTokenName, - HostID: "test-node", - Role: types.RoleNode, - PublicSSHKey: sshPublicKey, - PublicTLSKey: tlsPublicKey, + t.Run("legacy", func(t *testing.T) { + _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + JoinMethod: types.JoinMethodAzure, + ID: state.IdentityID{ + Role: types.RoleInstance, + HostUUID: "testuuid", }, - AttestedData: signedADBytes, - AccessToken: accessToken, - } - return req, tc.challengeResponseErr - }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient)) - tc.assertError(t, err) + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("new", func(t *testing.T) { + _, err = joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleInstance, + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, + }, + }) + tc.assertError(t, err) + }) }) } } // TestAuth_RegisterUsingAzureClaims tests the Azure join method by verifying // joining VMs by the token claims rather than from the Azure VM API. -func TestAuth_RegisterUsingAzureClaims(t *testing.T) { +func TestJoinAzureClaims(t *testing.T) { t.Parallel() + ctx := t.Context() - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - p, err := newTestPack(ctx, t.TempDir()) - require.NoError(t, err) - a := p.a - - sshPrivateKey, sshPublicKey, err := testauthority.New().GenerateKeyPair() - require.NoError(t, err) - - tlsConfig, err := fixtures.LocalTLSConfig() + server, err := authtest.NewTestServer(authtest.ServerConfig{ + Auth: authtest.AuthServerConfig{ + Dir: t.TempDir(), + }, + }) require.NoError(t, err) + a := server.Auth() - block, _ := pem.Decode(fixtures.LocalhostKey) - pkey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + nopClient, err := server.NewClient(authtest.TestNop()) require.NoError(t, err) - tlsPublicKey, err := authtest.PrivateKeyToPublicKeyTLS(sshPrivateKey) + caChain := newFakeAzureCAChain(t) + httpClient := newFakeAzureIssuerHTTPClient(caChain.intermediateCertDER) + instanceKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) + instanceCert := caChain.issueLeafCert(t, + instanceKey.Public(), + "instance.metadata.azure.com", + "http://www.microsoft.com/pkiops/certs/testcert.crt") isAccessDenied := func(t require.TestingT, err error, _ ...any) { require.True(t, trace.IsAccessDenied(err), "expected Access Denied error, actual error: %v", err) @@ -571,6 +578,17 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { defaultIdentityName := "test-id" defaultVMID := "my-vm-id" + botName := "botty" + _, err = machineidv1.UpsertBot(ctx, a, &machineidv1pb.Bot{ + Kind: types.KindBot, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: botName, + }, + Spec: &machineidv1pb.BotSpec{}, + }, a.GetClock().Now(), "") + require.NoError(t, err) + tests := []struct { name string tokenManagedIdentityResourceID string @@ -579,10 +597,9 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { tokenVMID string requestTokenName string tokenSpec types.ProvisionTokenSpecV2 - challengeResponseOptions []azureChallengeResponseOption challengeResponseErr error certs []*x509.Certificate - verify auth.AzureVerifyTokenFunc + verify azurejoin.AzureVerifyTokenFunc assertError require.ErrorAssertionFunc }{ { @@ -604,7 +621,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -626,7 +643,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -648,7 +665,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -671,7 +688,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -694,7 +711,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -717,7 +734,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -740,7 +757,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, { @@ -762,7 +779,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: isAccessDenied, }, { @@ -784,7 +801,7 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { JoinMethod: types.JoinMethodAzure, }, verify: mockVerifyToken(nil), - certs: []*x509.Certificate{tlsConfig.Certificate}, + certs: []*x509.Certificate{caChain.rootCert}, assertError: require.NoError, }, } @@ -813,45 +830,238 @@ func TestAuth_RegisterUsingAzureClaims(t *testing.T) { defaultSubscription: vmClient, }) - _, err = a.RegisterUsingAzureMethodWithOpts(context.Background(), func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error) { - cfg := &azureChallengeResponseConfig{Challenge: challenge} - for _, opt := range tc.challengeResponseOptions { - opt(cfg) - } + a.SetAzureJoinConfig(&azurejoin.AzureJoinConfig{ + CertificateAuthorities: tc.certs, + Verify: tc.verify, + GetVMClient: getVMClient, + IssuerHTTPClient: httpClient, + }) - ad := auth.AttestedData{ - Nonce: cfg.Challenge, - SubscriptionID: tc.tokenSubscription, - ID: tc.tokenVMID, - } - adBytes, err := json.Marshal(&ad) - require.NoError(t, err) - s, err := pkcs7.NewSignedData(adBytes) - require.NoError(t, err) - require.NoError(t, s.AddSigner(tlsConfig.Certificate, pkey, pkcs7.SignerInfoConfig{})) - signature, err := s.Finish() - require.NoError(t, err) - signedAD := auth.SignedAttestedData{ - Encoding: "pkcs7", - Signature: base64.StdEncoding.EncodeToString(signature), - } - signedADBytes, err := json.Marshal(&signedAD) + imdsClient := &fakeIMDSClient{ + accessToken: accessToken, + accessTokenErr: tc.challengeResponseErr, + signingCert: instanceCert, + signingKey: instanceKey, + subscription: tc.tokenSubscription, + vmID: tc.tokenVMID, + } + + t.Run("legacy", func(t *testing.T) { + // Try to join via the legacy join service. + _, err = joinclient.LegacyJoin(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + JoinMethod: types.JoinMethodAzure, + ID: state.IdentityID{ + Role: types.RoleInstance, + HostUUID: "testuuid", + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("new", func(t *testing.T) { + // Try to join via the new join service. + _, err = joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleInstance, + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, + }, + }) + tc.assertError(t, err) + }) + t.Run("bot", func(t *testing.T) { + // Try to join as a bot. + tokenSpec := tc.tokenSpec + tokenSpec.BotName = botName + tokenSpec.Roles = types.SystemRoles{types.RoleBot} + token, err := types.NewProvisionTokenFromSpec( + "test-token", + time.Now().Add(time.Minute), + tokenSpec) require.NoError(t, err) + require.NoError(t, a.UpsertToken(ctx, token)) - req := &proto.RegisterUsingAzureMethodRequest{ - RegisterUsingTokenRequest: &types.RegisterUsingTokenRequest{ - Token: tc.requestTokenName, - HostID: "test-node", - Role: types.RoleNode, - PublicSSHKey: sshPublicKey, - PublicTLSKey: tlsPublicKey, + result, err := joinclient.Join(ctx, joinclient.JoinParams{ + Token: tc.requestTokenName, + ID: state.IdentityID{ + Role: types.RoleBot, + }, + AuthClient: nopClient, + AzureParams: joinclient.AzureParams{ + ClientID: tc.tokenVMID, + IMDSClient: imdsClient, + IssuerHTTPClient: httpClient, }, - AttestedData: signedADBytes, - AccessToken: accessToken, + }) + tc.assertError(t, err) + if err != nil { + return } - return req, tc.challengeResponseErr - }, auth.WithAzureCerts(tc.certs), auth.WithAzureVerifyFunc(tc.verify), auth.WithAzureVMClientGetter(getVMClient)) - tc.assertError(t, err) + + cert, err := tlsca.ParseCertificatePEM(result.Certs.TLS) + require.NoError(t, err) + identity, err := tlsca.FromSubject(cert.Subject, cert.NotAfter) + require.NoError(t, err) + + // Make sure the LoginIP was set on the identity. + require.NotEmpty(t, identity.LoginIP) + + // Make sure the JoinAttributes were set. + require.NotNil(t, identity.JoinAttributes) + require.NotNil(t, identity.JoinAttributes.Azure) + require.Equal(t, tc.tokenSubscription, identity.JoinAttributes.Azure.Subscription) + }) }) } } + +type fakeIMDSClient struct { + accessToken string + accessTokenErr error + + // overrideChallenge overrides the challenge/nonce included in attested data. + overrideChallenge string + signingCert *x509.Certificate + signingKey crypto.Signer + subscription string + vmID string +} + +func (c *fakeIMDSClient) IsAvailable(_ context.Context) bool { + return true +} + +func (c *fakeIMDSClient) GetAttestedData(_ context.Context, nonce string) ([]byte, error) { + ad := azurejoin.AttestedData{ + Nonce: nonce, + SubscriptionID: c.subscription, + ID: c.vmID, + } + if c.overrideChallenge != "" { + ad.Nonce = c.overrideChallenge + } + adBytes, err := json.Marshal(&ad) + if err != nil { + return nil, trace.Wrap(err) + } + s, err := pkcs7.NewSignedData(adBytes) + if err != nil { + return nil, trace.Wrap(err) + } + if err := s.AddSigner(c.signingCert, c.signingKey, pkcs7.SignerInfoConfig{}); err != nil { + return nil, trace.Wrap(err) + } + signature, err := s.Finish() + if err != nil { + return nil, trace.Wrap(err) + } + signedAD := azurejoin.SignedAttestedData{ + Encoding: "pkcs7", + Signature: base64.StdEncoding.EncodeToString(signature), + } + signedADBytes, err := json.Marshal(&signedAD) + if err != nil { + return nil, trace.Wrap(err) + } + return signedADBytes, nil +} + +func (c *fakeIMDSClient) GetAccessToken(_ context.Context, clientID string) (string, error) { + return c.accessToken, trace.Wrap(c.accessTokenErr) +} + +type fakeAzureCAChain struct { + intermediateKey crypto.Signer + intermediateCert *x509.Certificate + intermediateCertDER []byte + rootCert *x509.Certificate +} + +func newFakeAzureCAChain(t *testing.T) *fakeAzureCAChain { + rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + rootCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test root CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + rootCertDER, err := x509.CreateCertificate(rand.Reader, rootCertTemplate, rootCertTemplate, rootKey.Public(), rootKey) + require.NoError(t, err) + rootCert, err := x509.ParseCertificate(rootCertDER) + require.NoError(t, err) + + intermediateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + intermediateCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "test intermediate CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageCertSign, + IsCA: true, + BasicConstraintsValid: true, + } + intermediateCertDER, err := x509.CreateCertificate(rand.Reader, intermediateCertTemplate, rootCert, intermediateKey.Public(), rootKey) + require.NoError(t, err) + intermediateCert, err := x509.ParseCertificate(intermediateCertDER) + require.NoError(t, err) + + return &fakeAzureCAChain{ + intermediateKey: intermediateKey, + intermediateCert: intermediateCert, + intermediateCertDER: intermediateCertDER, + rootCert: rootCert, + } +} + +func (c *fakeAzureCAChain) issueLeafCert(t *testing.T, pub crypto.PublicKey, commonName, issuerURL string) *x509.Certificate { + leafCertTemplate := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: commonName, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + IssuingCertificateURL: []string{issuerURL}, + KeyUsage: x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + } + leafCertDER, err := x509.CreateCertificate(rand.Reader, leafCertTemplate, c.intermediateCert, pub, c.intermediateKey) + require.NoError(t, err) + leafCert, err := x509.ParseCertificate(leafCertDER) + require.NoError(t, err) + return leafCert +} + +type fakeAzureIssuerHTTPClient struct { + issuerCertDER []byte + called int +} + +func newFakeAzureIssuerHTTPClient(issuerCertDER []byte) *fakeAzureIssuerHTTPClient { + return &fakeAzureIssuerHTTPClient{ + issuerCertDER: issuerCertDER, + } +} +func (c *fakeAzureIssuerHTTPClient) Do(req *http.Request) (*http.Response, error) { + c.called++ + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(bytes.NewReader(c.issuerCertDER)), + }, nil +} diff --git a/lib/join/joinclient/join.go b/lib/join/joinclient/join.go index 1374b9b5a030d..7dc1f80cace3b 100644 --- a/lib/join/joinclient/join.go +++ b/lib/join/joinclient/join.go @@ -204,6 +204,7 @@ func joinWithClient(ctx context.Context, params JoinParams, client *joinv1.Clien case types.JoinMethodUnspecified: // leave joinMethodPtr nil to let the server pick based on the token case types.JoinMethodToken, + types.JoinMethodAzure, types.JoinMethodAzureDevops, types.JoinMethodBitbucket, types.JoinMethodBoundKeypair, @@ -306,6 +307,8 @@ func joinWithMethod( switch types.JoinMethod(method) { case types.JoinMethodToken: return tokenJoin(stream, clientParams) + case types.JoinMethodAzure: + return azureJoin(ctx, stream, joinParams, clientParams) case types.JoinMethodAzureDevops: if joinParams.IDToken == "" { joinParams.IDToken, err = azuredevops.NewIDTokenSource(os.Getenv).GetIDToken(ctx) diff --git a/lib/join/joinclient/join_azure.go b/lib/join/joinclient/join_azure.go new file mode 100644 index 0000000000000..a359d68e76f36 --- /dev/null +++ b/lib/join/joinclient/join_azure.go @@ -0,0 +1,135 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package joinclient + +import ( + "context" + "crypto/x509" + "net/http" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/lib/cloud/imds/azure" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/utils" +) + +func azureJoin(ctx context.Context, stream messages.ClientStream, joinParams JoinParams, clientParams messages.ClientParams) (messages.Response, error) { + // The Azure join method involves the following messages: + // + // client->server ClientInit + // client<-server ServerInit + // client->server AzureInit + // client<-server AzureChallenge + // client->server AzureChallengeSolution + // client<-server Result + // + // At this point the ServerInit messages has already been received, what's + // left is to send the AzureInit message, handle the challenge-response, and + // receive and return the final result. + if err := stream.Send(&messages.AzureInit{ + ClientParams: clientParams, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureInit") + } + + challenge, err := messages.RecvResponse[*messages.AzureChallenge](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureChallenge") + } + + imds := joinParams.AzureParams.IMDSClient + if imds == nil { + imds = azure.NewInstanceMetadataClient() + } + if !imds.IsAvailable(ctx) { + return nil, trace.AccessDenied("could not reach instance metadata. Is Teleport running on an Azure VM?") + } + ad, err := imds.GetAttestedData(ctx, challenge.Challenge) + if err != nil { + return nil, trace.Wrap(err, "getting attested data document") + } + intermediate, err := getIntermediate(ctx, joinParams.AzureParams.IssuerHTTPClient, ad) + if err != nil { + return nil, trace.Wrap(err, "getting intermediate CA for attested data") + } + accessToken, err := imds.GetAccessToken(ctx, joinParams.AzureParams.ClientID) + if err != nil { + return nil, trace.Wrap(err, "getting access token") + } + + if err := stream.Send(&messages.AzureChallengeSolution{ + AttestedData: ad, + Intermediate: intermediate, + AccessToken: accessToken, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureChallengeSolution") + } + + result, err := stream.Recv() + return result, trace.Wrap(err, "receiving join result") +} + +func getIntermediate(ctx context.Context, httpClient utils.HTTPDoClient, ad []byte) ([]byte, error) { + _, p7, err := azurejoin.ParseAttestedData(ad) + if err != nil { + return nil, trace.Wrap(err, "parsing attested data document") + } + if len(p7.Certificates) == 0 { + return nil, trace.Errorf("attested data signature has no certificates") + } + leafCert := p7.Certificates[0] + if len(leafCert.IssuingCertificateURL) == 0 { + return nil, trace.Errorf("attested data leaf certificate has no issuing certificate URL") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, leafCert.IssuingCertificateURL[0], nil /*body*/) + if err != nil { + return nil, trace.Wrap(err, "building HTTP request") + } + + if httpClient == nil { + httpClient, err = defaults.HTTPClient() + if err != nil { + return nil, trace.Wrap(err) + } + } + + resp, err := httpClient.Do(req) + if err != nil { + return nil, trace.Wrap(err, "fetching intermediate certificate") + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, trace.Errorf("failed to fetch intermediate cert, got HTTP status code %d", resp.StatusCode) + } + + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) + if err != nil { + return nil, trace.Wrap(err, "reading HTTP response body") + } + + if _, err := x509.ParseCertificates(body); err != nil { + return nil, trace.Wrap(err, "parsing intermediate certificate") + } + + return body, nil +} diff --git a/lib/join/server.go b/lib/join/server.go index ea4601d507a6c..e17b926978858 100644 --- a/lib/join/server.go +++ b/lib/join/server.go @@ -46,6 +46,7 @@ import ( "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/join/azuredevops" + "github.com/gravitational/teleport/lib/join/azurejoin" "github.com/gravitational/teleport/lib/join/bitbucket" "github.com/gravitational/teleport/lib/join/circleci" "github.com/gravitational/teleport/lib/join/ec2join" @@ -104,6 +105,7 @@ type AuthService interface { GetSpaceliftIDTokenValidator() spacelift.Validator GetTPMValidator() tpmjoin.TPMValidator GetTerraformIDTokenValidator() terraformcloud.Validator + GetAzureJoinConfig() *azurejoin.AzureJoinConfig services.Presence } @@ -302,6 +304,8 @@ func (s *Server) handleJoinMethod( joinMethod types.JoinMethod, ) (messages.Response, error) { switch joinMethod { + case types.JoinMethodAzure: + return s.handleAzureJoin(stream, authCtx, clientInit, token) case types.JoinMethodAzureDevops: return s.handleOIDCJoin(stream, authCtx, clientInit, token, s.validateAzureDevopsToken) case types.JoinMethodBitbucket: diff --git a/lib/join/server_azure.go b/lib/join/server_azure.go new file mode 100644 index 0000000000000..c0687eb8753cd --- /dev/null +++ b/lib/join/server_azure.go @@ -0,0 +1,118 @@ +// Teleport +// Copyright (C) 2025 Gravitational, Inc. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package join + +import ( + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/join/azurejoin" + "github.com/gravitational/teleport/lib/join/internal/authz" + "github.com/gravitational/teleport/lib/join/internal/messages" + "github.com/gravitational/teleport/lib/join/provision" +) + +// handleAzureJoin handles join attempts for the Azure join method. +// +// The Azure join method involves the following messages: +// +// client->server ClientInit +// client<-server ServerInit +// client->server AzureInit +// client<-server AzureChallenge +// client->server AzureChallengeSolution +// client<-server Result +// +// At this point the ServerInit message has already been sent, what's left is +// to receive the AzureInit message, handle the challenge-response, and send the +// final result if everything checks out. +func (s *Server) handleAzureJoin( + stream messages.ServerStream, + authCtx *authz.Context, + clientInit *messages.ClientInit, + token provision.Token, +) (messages.Response, error) { + // Receive the AzureInit message from the client. + azureInit, err := messages.RecvRequest[*messages.AzureInit](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureInit message") + } + // Set any diagnostic info from the ClientParams. + setDiagnosticClientParams(stream.Diagnostic(), &azureInit.ClientParams) + + // Generate and send the challenge. + challenge, err := azurejoin.GenerateAzureChallenge() + if err != nil { + return nil, trace.Wrap(err, "generating challenge") + } + if err := stream.Send(&messages.AzureChallenge{ + Challenge: challenge, + }); err != nil { + return nil, trace.Wrap(err, "sending AzureChallenge") + } + + // Receive the solution from the client. + solution, err := messages.RecvRequest[*messages.AzureChallengeSolution](stream) + if err != nil { + return nil, trace.Wrap(err, "receiving AzureChallengeSolution") + } + + switch { + case len(solution.AttestedData) == 0: + return nil, trace.BadParameter("client did not send attested data") + case len(solution.Intermediate) == 0: + return nil, trace.BadParameter("client did not send intermediate CAs") + case len(solution.AccessToken) == 0: + return nil, trace.BadParameter("client did not send access token") + } + + ptv2, ok := token.(*types.ProvisionTokenV2) + if !ok { + return nil, trace.BadParameter("Azure join method only supports ProvisionTokenV2, got %T", token) + } + + // Verify the client's identity and make sure it matches an allow rule in the provision token. + claims, err := azurejoin.CheckAzureRequest(stream.Context(), azurejoin.CheckAzureRequestParams{ + AzureJoinConfig: s.cfg.AuthService.GetAzureJoinConfig(), + Token: ptv2, + Challenge: challenge, + AttestedData: solution.AttestedData, + Intermediate: solution.Intermediate, + AccessToken: solution.AccessToken, + Logger: log, + Clock: s.cfg.AuthService.GetClock(), + }) + if err != nil { + return nil, trace.Wrap(err, "checking Azure challenge solution") + } + + // Make and return the final result message. + result, err := s.makeResult( + stream.Context(), + stream.Diagnostic(), + authCtx, + clientInit, + &azureInit.ClientParams, + token, + nil, // rawClaims + &workloadidentityv1pb.JoinAttrs{ + Azure: claims, + }, + ) + return result, trace.Wrap(err) +}