Skip to content
75 changes: 75 additions & 0 deletions api/client/joinservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ type RegisterIAMChallengeResponseFunc func(challenge string) (*proto.RegisterUsi
// *proto.RegisterUsingAzureMethodRequest for a given challenge, or an error.
type RegisterAzureChallengeResponseFunc func(challenge string) (*proto.RegisterUsingAzureMethodRequest, error)

// RegisterTPMChallengeResponseFunc is a function type meant to be passed to
// RegisterUsingTPMMethod. It must return a
// *proto.RegisterUsingTPMMethodChallengeResponse for a given challenge, or an
// error.
type RegisterTPMChallengeResponseFunc func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error)

// RegisterUsingIAMMethod registers the caller using the IAM join method and
// returns signed certs to join the cluster.
//
Expand Down Expand Up @@ -125,3 +131,72 @@ func (c *JoinServiceClient) RegisterUsingAzureMethod(ctx context.Context, challe
}
return certsResp.Certs, nil
}

// RegisterUsingTPMMethod registers the caller using the TPM join method and
// returns signed certs to join the cluster. The caller must provide a
// ChallengeResponseFunc which returns a *proto.RegisterUsingTPMMethodRequest
// for a given challenge, or an error.
func (c *JoinServiceClient) RegisterUsingTPMMethod(
Comment thread
strideynet marked this conversation as resolved.
ctx context.Context,
initReq *proto.RegisterUsingTPMMethodInitialRequest,
solveChallenge RegisterTPMChallengeResponseFunc,
) (*proto.Certs, error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

stream, err := c.grpcClient.RegisterUsingTPMMethod(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
Comment thread
strideynet marked this conversation as resolved.
defer stream.CloseSend()

err = stream.Send(&proto.RegisterUsingTPMMethodRequest{
Payload: &proto.RegisterUsingTPMMethodRequest_Init{
Init: initReq,
},
})
if err != nil {
return nil, trace.Wrap(err, "sending initial request")
}

res, err := stream.Recv()
if err != nil {
return nil, trace.Wrap(err, "receiving challenge")
}

challenge := res.GetChallengeRequest()
if challenge == nil {
return nil, trace.BadParameter(
"expected ChallengeRequest payload, got %T",
res.Payload,
)
}

solution, err := solveChallenge(challenge)
if err != nil {
return nil, trace.Wrap(err, "solving challenge")
}

err = stream.Send(&proto.RegisterUsingTPMMethodRequest{
Payload: &proto.RegisterUsingTPMMethodRequest_ChallengeResponse{
ChallengeResponse: solution,
},
})
if err != nil {
return nil, trace.Wrap(err, "sending solution")
}

res, err = stream.Recv()
if err != nil {
return nil, trace.Wrap(err, "receiving certs")
}
certs := res.GetCerts()
if certs == nil {
return nil, trace.BadParameter(
"expected Certs payload, got %T",
res.Payload,
)
}

return certs, nil
}
137 changes: 137 additions & 0 deletions api/client/joinservice_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
Copyright 2021 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package client

import (
"context"
"errors"
"net"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"

"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
)

type mockJoinServiceServer struct {
*proto.UnimplementedJoinServiceServer
registerUsingTPMMethod func(srv proto.JoinService_RegisterUsingTPMMethodServer) error
}

func (m *mockJoinServiceServer) RegisterUsingTPMMethod(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
return m.registerUsingTPMMethod(srv)
}

func TestJoinServiceClient_RegisterUsingTPMMethod(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)

lis := bufconn.Listen(100)
t.Cleanup(func() {
assert.NoError(t, lis.Close())
})

mockInitReq := &proto.RegisterUsingTPMMethodInitialRequest{
JoinRequest: &types.RegisterUsingTokenRequest{
Token: "token",
},
}
mockChallenge := &proto.TPMEncryptedCredential{
CredentialBlob: []byte("cred-blob"),
Secret: []byte("secret"),
}
mockChallengeResp := &proto.RegisterUsingTPMMethodChallengeResponse{
Solution: []byte("solution"),
}
mockCerts := &proto.Certs{
TLS: []byte("cert"),
}
mockService := &mockJoinServiceServer{
registerUsingTPMMethod: func(srv proto.JoinService_RegisterUsingTPMMethodServer) error {
req, err := srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(req.GetInit(), mockInitReq))

err = srv.Send(&proto.RegisterUsingTPMMethodResponse{
Payload: &proto.RegisterUsingTPMMethodResponse_ChallengeRequest{
ChallengeRequest: mockChallenge,
},
})
if !assert.NoError(t, err) {
return err
}

req, err = srv.Recv()
if !assert.NoError(t, err) {
return err
}
assert.Empty(t, cmp.Diff(req.GetChallengeResponse(), mockChallengeResp))

err = srv.Send(&proto.RegisterUsingTPMMethodResponse{
Payload: &proto.RegisterUsingTPMMethodResponse_Certs{
Certs: mockCerts,
},
})
if !assert.NoError(t, err) {
return err
}
return nil
},
}
srv := grpc.NewServer()
t.Cleanup(func() {
srv.Stop()
})
proto.RegisterJoinServiceServer(srv, mockService)

go func() {
err := srv.Serve(lis)
if err != nil && !errors.Is(err, grpc.ErrServerStopped) {
assert.NoError(t, err)
}
cancel()
}()

c, err := grpc.NewClient("unused.com", grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
return lis.DialContext(ctx)
}), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)

joinClient := NewJoinServiceClient(proto.NewJoinServiceClient(c))

certs, err := joinClient.RegisterUsingTPMMethod(
ctx,
mockInitReq,
func(challenge *proto.TPMEncryptedCredential) (*proto.RegisterUsingTPMMethodChallengeResponse, error) {
assert.Empty(t, cmp.Diff(mockChallenge, challenge))
return mockChallengeResp, nil
},
)
if assert.NoError(t, err) {
assert.Empty(t, cmp.Diff(mockCerts, certs))
}
}
58 changes: 58 additions & 0 deletions api/types/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package types

import (
"crypto/x509"
"encoding/pem"
"fmt"
"slices"
"strings"
Expand Down Expand Up @@ -66,6 +68,9 @@ const (
// method. Documentation regarding implementation of this can be found in
// lib/spacelift.
JoinMethodSpacelift JoinMethod = "spacelift"
// JoinMethodTPM indicates that the node will join with the TPM join method.
// The core implementation of this join method can be found in lib/tpm.
JoinMethodTPM JoinMethod = "tpm"
)

var JoinMethods = []JoinMethod{
Expand All @@ -79,6 +84,7 @@ var JoinMethods = []JoinMethod{
JoinMethodKubernetes,
JoinMethodSpacelift,
JoinMethodToken,
JoinMethodTPM,
}

func ValidateJoinMethod(method JoinMethod) error {
Expand Down Expand Up @@ -328,6 +334,17 @@ func (p *ProvisionTokenV2) CheckAndSetDefaults() error {
if err := providerCfg.checkAndSetDefaults(); err != nil {
return trace.Wrap(err, "spec.spacelift: failed validation")
}
case JoinMethodTPM:
providerCfg := p.Spec.TPM
if providerCfg == nil {
return trace.BadParameter(
`spec.tpm: must be configured for the join method %q`,
JoinMethodTPM,
)
}
if err := providerCfg.validate(); err != nil {
return trace.Wrap(err, "spec.tpm: failed validation")
}
default:
return trace.BadParameter("unknown join method %q", p.Spec.JoinMethod)
}
Expand Down Expand Up @@ -754,3 +771,44 @@ func (a *ProvisionTokenSpecV2Spacelift) checkAndSetDefaults() error {
}
return nil
}

func (a *ProvisionTokenSpecV2TPM) validate() error {
for i, caData := range a.EKCertAllowedCAs {
p, _ := pem.Decode([]byte(caData))
if p == nil {
return trace.BadParameter(
"ekcert_allowed_cas[%d]: no pem block found",
i,
)
}
if p.Type != "CERTIFICATE" {
return trace.BadParameter(
"ekcert_allowed_cas[%d]: pem block is not 'CERTIFICATE' type",
i,
)
}
if _, err := x509.ParseCertificate(p.Bytes); err != nil {
return trace.Wrap(
err,
"ekcert_allowed_cas[%d]: parsing certificate",
i,
)

}
}

if len(a.Allow) == 0 {
return trace.BadParameter(
"allow: at least one rule must be set",
)
}
for i, allowRule := range a.Allow {
if len(allowRule.EKPublicHash) == 0 && len(allowRule.EKCertificateSerial) == 0 {
return trace.BadParameter(
"allow[%d]: at least one of ['ek_public_hash', 'ek_certificate_serial'] must be set",
i,
)
}
}
return nil
}
Loading