Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions integration/proxy/proxy_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,25 @@ import (
"crypto/tls"
"crypto/x509/pkix"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"net/url"
"path/filepath"
"strings"
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jackc/pgconn"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/tools/clientcmd"
Expand All @@ -46,7 +51,9 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/fixtures"
"github.com/gravitational/teleport/lib/reversetunnel"
"github.com/gravitational/teleport/lib/service/servicecfg"
"github.com/gravitational/teleport/lib/services"
Expand All @@ -55,6 +62,7 @@ import (
"github.com/gravitational/teleport/lib/srv/db/postgres"
"github.com/gravitational/teleport/lib/tlsca"
"github.com/gravitational/teleport/lib/utils"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)

type Suite struct {
Expand Down Expand Up @@ -692,3 +700,75 @@ func mustParseURL(t *testing.T, rawURL string) *url.URL {
require.NoError(t, err)
return u
}

// fakeSTSClient is a fake HTTP client used to fake STS responses when Auth
// server sends out pre-signed STS requests for IAM join verification.
type fakeSTSClient struct {
accountID string
arn string
credentials *credentials.Credentials
}

func (f fakeSTSClient) Do(req *http.Request) (*http.Response, error) {
if err := awsutils.VerifyAWSSignature(req, f.credentials); err != nil {
return nil, trace.Wrap(err)
}
response := fmt.Sprintf(`{"GetCallerIdentityResponse": {"GetCallerIdentityResult": {"Account": "%s", "Arn": "%s" }}}`, f.accountID, f.arn)
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(response)),
}, nil
}

func mustCreateIAMJoinProvisionToken(t *testing.T, name, awsAccountID, allowedARN string) types.ProvisionToken {
t.Helper()

provisionToken, err := types.NewProvisionTokenFromSpec(
name,
time.Now().Add(time.Hour),
types.ProvisionTokenSpecV2{
Roles: []types.SystemRole{types.RoleNode},
Allow: []*types.TokenRule{
{
AWSAccount: awsAccountID,
AWSARN: allowedARN,
},
},
JoinMethod: types.JoinMethodIAM,
},
)
require.NoError(t, err)
return provisionToken
}

func mustRegisterUsingIAMMethod(t *testing.T, proxyAddr utils.NetAddr, token string, credentials *credentials.Credentials) {
t.Helper()

cred, err := credentials.Get()
require.NoError(t, err)

t.Setenv("AWS_ACCESS_KEY_ID", cred.AccessKeyID)
t.Setenv("AWS_SECRET_ACCESS_KEY", cred.SecretAccessKey)
t.Setenv("AWS_SESSION_TOKEN", cred.SessionToken)
t.Setenv("AWS_REGION", "us-west-2")

privateKey, err := ssh.ParseRawPrivateKey([]byte(fixtures.SSHCAPrivateKey))
require.NoError(t, err)
pubTLS, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(privateKey)
require.NoError(t, err)

node := uuid.NewString()
_, err = auth.Register(auth.RegisterParams{
Token: token,
ID: auth.IdentityID{
Role: types.RoleNode,
HostUUID: node,
NodeName: node,
},
ProxyServer: proxyAddr,
JoinMethod: types.JoinMethodIAM,
PublicTLSKey: pubTLS,
PublicSSHKey: []byte(fixtures.SSHCAPublicKey),
})
require.NoError(t, err, trace.DebugReport(err))
}
38 changes: 38 additions & 0 deletions integration/proxy/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand All @@ -46,6 +47,7 @@ import (
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/integration/kube"
"github.com/gravitational/teleport/lib"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/auth/testauthority"
libclient "github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/defaults"
Expand Down Expand Up @@ -1505,3 +1507,39 @@ func TestALPNProxyHTTPProxyBasicAuthDial(t *testing.T) {
require.NoError(t, helpers.WaitForNodeCount(context.Background(), rc, rc.Secrets.SiteName, 1))
require.Greater(t, ph.Count(), 0)
}

// TestALPNSNIProxyGRPCInsecure tests ALPN protocol ProtocolProxyGRPCInsecure
// by registering a node with IAM join method.
func TestALPNSNIProxyGRPCInsecure(t *testing.T) {
lib.SetInsecureDevMode(true)
defer lib.SetInsecureDevMode(false)

nodeAccount := "123456789012"
nodeRoleARN := "arn:aws:iam::123456789012:role/test"
nodeCredentials := credentials.NewStaticCredentials("FAKE_ID", "FAKE_KEY", "FAKE_TOKEN")
provisionToken := mustCreateIAMJoinProvisionToken(t, "iam-join-token", nodeAccount, nodeRoleARN)

suite := newSuite(t,
withRootClusterConfig(rootClusterStandardConfig(t), func(config *servicecfg.Config) {
config.Auth.BootstrapResources = []types.Resource{provisionToken}
config.Auth.ServerOptions = []auth.ServerOption{auth.WithHTTPClientForAWSSTS(fakeSTSClient{
accountID: nodeAccount,
arn: nodeRoleARN,
credentials: nodeCredentials,
})}
}),
withLeafClusterConfig(leafClusterStandardConfig(t)),
)

// Test register through Proxy.
mustRegisterUsingIAMMethod(t, suite.root.Config.Proxy.WebAddr, provisionToken.GetName(), nodeCredentials)

// Test register through Proxy behind a L7 load balancer.
t.Run("ALPN conn upgrade", func(t *testing.T) {
albProxy := mustStartMockALBProxy(t, suite.root.Config.Proxy.WebAddr.Addr)
albAddr, err := utils.ParseAddr(albProxy.Addr().String())
require.NoError(t, err)

mustRegisterUsingIAMMethod(t, *albAddr, provisionToken.GetName(), nodeCredentials)
})
}
13 changes: 13 additions & 0 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,10 @@ type Server struct {
loginHooksMu sync.RWMutex
// loginHooks are a list of hooks that will be called on login.
loginHooks []LoginHook

// httpClientForAWSSTS overwrites the default HTTP client used for making
// STS requests.
httpClientForAWSSTS stsClient
}

// SetSAMLService registers svc as the SAMLService that provides the SAML
Expand Down Expand Up @@ -5398,3 +5402,12 @@ func DefaultDNSNamesForRole(role types.SystemRole) []string {
}
return nil
}

// WithHTTPClientForAWSSTS is a ServerOption that overwrites default HTTP
// client used for STS requests.
func WithHTTPClientForAWSSTS(client stsClient) ServerOption {
return func(s *Server) error {
s.httpClientForAWSSTS = client
return nil
}
}
2 changes: 1 addition & 1 deletion lib/auth/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ func NewServerIdentity(clt *Server, hostID string, role types.SystemRole) (*Iden
&proto.HostCertsRequest{
HostID: hostID,
NodeName: hostID,
Role: types.RoleAuth,
Role: role,
PublicTLSKey: publicTLS,
PublicSSHKey: pub,
})
Expand Down
19 changes: 5 additions & 14 deletions lib/auth/join_iam.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,21 +196,12 @@ type stsClient interface {
Do(*http.Request) (*http.Response, error)
}

type stsClientKey struct{}

// stsClientFromContext allows the default http client to be overridden for tests
func stsClientFromContext(ctx context.Context) stsClient {
client, ok := ctx.Value(stsClientKey{}).(stsClient)
if ok {
return client
}
return http.DefaultClient
}

// executeSTSIdentityRequest sends the sts:GetCallerIdentity HTTP request to the
// AWS API, parses the response, and returns the awsIdentity
func executeSTSIdentityRequest(ctx context.Context, req *http.Request) (*awsIdentity, error) {
client := stsClientFromContext(ctx)
func executeSTSIdentityRequest(ctx context.Context, client stsClient, req *http.Request) (*awsIdentity, error) {
if client == nil {
client = http.DefaultClient
}

// set the http request context so it can be canceled
req = req.WithContext(ctx)
Expand Down Expand Up @@ -312,7 +303,7 @@ func (a *Server) checkIAMRequest(ctx context.Context, challenge string, req *pro

// send the signed request to the public AWS API and get the node identity
// from the response
identity, err := executeSTSIdentityRequest(ctx, identityRequest)
identity, err := executeSTSIdentityRequest(ctx, a.httpClientForAWSSTS, identityRequest)
if err != nil {
return trace.Wrap(err)
}
Expand Down
4 changes: 3 additions & 1 deletion lib/auth/join_iam_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
// Set mock client.
a.httpClientForAWSSTS = tc.stsClient

// add token to auth server
token, err := types.NewProvisionTokenFromSpec(
tc.tokenName,
Expand All @@ -503,7 +506,6 @@ func TestAuth_RegisterUsingIAMMethod(t *testing.T) {

requestContext := context.Background()
requestContext = authz.ContextWithClientAddr(requestContext, &net.IPAddr{})
requestContext = context.WithValue(requestContext, stsClientKey{}, tc.stsClient)

_, err = a.RegisterUsingIAMMethod(requestContext, func(challenge string) (*proto.RegisterUsingIAMMethodRequest, error) {
templateInput := defaultIdentityRequestTemplateInput(challenge)
Expand Down
54 changes: 51 additions & 3 deletions lib/auth/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@ package auth

import (
"context"
"crypto/tls"
"crypto/x509"
"os"
"time"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"golang.org/x/exp/slices"
"golang.org/x/net/http2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
Expand All @@ -32,6 +34,8 @@ import (
"github.com/gravitational/teleport/api/breaker"
"github.com/gravitational/teleport/api/client"
"github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/constants"
apidefaults "github.com/gravitational/teleport/api/defaults"
"github.com/gravitational/teleport/api/metadata"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib"
Expand Down Expand Up @@ -300,7 +304,7 @@ func registerThroughProxy(token string, params RegisterParams) (*proto.Certs, er
switch params.JoinMethod {
case types.JoinMethodIAM, types.JoinMethodAzure:
// IAM and Azure join methods require gRPC client
conn, err := proxyJoinServiceConn(params)
conn, err := proxyJoinServiceConn(params, lib.IsInsecureDevMode())
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down Expand Up @@ -402,26 +406,70 @@ func registerThroughAuth(token string, params RegisterParams) (*proto.Certs, err
// proxyJoinServiceConn attempts to connect to the join service running on the
// proxy. The Proxy's TLS cert will be verified using the host's root CA pool
// (PKI) unless the --insecure flag was passed.
func proxyJoinServiceConn(params RegisterParams) (*grpc.ClientConn, error) {
func proxyJoinServiceConn(params RegisterParams, insecure bool) (*grpc.ClientConn, error) {
tlsConfig := utils.TLSConfig(params.CipherSuites)
tlsConfig.Time = params.Clock.Now
// set NextProtos for TLS routing, the actual protocol will be h2
tlsConfig.NextProtos = []string{string(common.ProtocolProxyGRPCInsecure), http2.NextProtoTLS}

if lib.IsInsecureDevMode() {
if insecure {
tlsConfig.InsecureSkipVerify = true
log.Warnf("Joining cluster without validating the identity of the Proxy Server.")
}

// Check if proxy is behind a load balancer. If so, the connection upgrade
// will verify the load balancer's cert using system cert pool. This
// provides the same level of security as the client only verifies Proxy's
// web cert against system cert pool when connection upgrade is not
// required.
//
// With the ALPN connection upgrade, the tunneled TLS Routing request will
// skip verify as the Proxy server will present its host cert which is not
// fully verifiable at this point since the client does not have the host
// CAs yet before completing registration.
alpnConnUpgrade := client.IsALPNConnUpgradeRequired(getHostAddresses(params)[0], insecure)
if alpnConnUpgrade && !insecure {
tlsConfig.InsecureSkipVerify = true
tlsConfig.VerifyConnection = verifyALPNUpgradedConn(params.Clock)
}

dialer := client.NewDialer(
context.Background(),
apidefaults.DefaultIdleTimeout,
apidefaults.DefaultIOTimeout,
client.WithInsecureSkipVerify(insecure),
client.WithALPNConnUpgrade(alpnConnUpgrade),
)

conn, err := grpc.Dial(
getHostAddresses(params)[0],
grpc.WithContextDialer(client.GRPCContextDialer(dialer)),
grpc.WithUnaryInterceptor(metadata.UnaryClientInterceptor),
grpc.WithStreamInterceptor(metadata.StreamClientInterceptor),
grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)),
)
return conn, trace.Wrap(err)
}

// verifyALPNUpgradedConn is a tls.Config.VerifyConnection callback function
// used by the tunneled TLS Routing request to verify the host cert of a Proxy
// behind a L7 load balancer.
//
// Since the client has not obtained the cluster CAs at this point, the
// presented cert cannot be fully verified yet. For now, this function only
// checks if "teleport.cluster.local" is present as one of the DNS names and
// verifies the cert is not expired.
func verifyALPNUpgradedConn(clock clockwork.Clock) func(tls.ConnectionState) error {
return func(server tls.ConnectionState) error {
for _, cert := range server.PeerCertificates {
if slices.Contains(cert.DNSNames, constants.APIDomain) && clock.Now().Before(cert.NotAfter) {
return nil
}
}
return trace.AccessDenied("server is not a Teleport proxy or server certificate is expired")
}
}

// insecureRegisterClient attempts to connects to the Auth Server using the
// CA on disk. If no CA is found on disk, Teleport will not verify the Auth
// Server it is connecting to.
Expand Down
Loading