Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions api/types/authority_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,17 @@ func TestRotationZero(t *testing.T) {
require.Equal(t, tt.z, tt.r.IsZero(), tt.d)
}
}

// Test that the spec cluster name name will be set to match the resource name
func TestCheckAndSetDefaults(t *testing.T) {
ca := CertAuthorityV2{
Metadata: Metadata{Name: "caName"},
Spec: CertAuthoritySpecV2{
ClusterName: "clusterName",
Type: HostCA,
},
}
err := ca.CheckAndSetDefaults()
require.NoError(t, err)
require.Equal(t, ca.Metadata.Name, ca.Spec.ClusterName)
}
3 changes: 2 additions & 1 deletion integration/assist/command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ func newTestCredentials(t *testing.T, rc *helpers.TeleInstance, user types.User)
}

pool := x509.NewCertPool()
pool.AppendCertsFromPEM(rc.Secrets.TLSCACert)
pool.AppendCertsFromPEM(rc.Secrets.TLSHostCACert)
pool.AppendCertsFromPEM(rc.Secrets.TLSUserCACert)

tlsConf := &tls.Config{
Certificates: []tls.Certificate{cert},
Expand Down
111 changes: 67 additions & 44 deletions integration/helpers/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ package helpers
import (
"bytes"
"context"
"crypto/rsa"
"crypto/tls"
"crypto/x509/pkix"
"encoding/json"
Expand All @@ -44,6 +43,7 @@ import (
"github.com/gravitational/teleport/api/breaker"
clientproto "github.com/gravitational/teleport/api/client/proto"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/lib/auth/authclient"
"github.com/gravitational/teleport/lib/auth/keygen"
"github.com/gravitational/teleport/lib/auth/state"
Expand Down Expand Up @@ -97,11 +97,15 @@ type InstanceSecrets struct {
// PrivKey is instance private key
PrivKey []byte `json:"priv"`
// Cert is SSH host certificate
Cert []byte `json:"cert"`
// TLSCACert is the certificate of the trusted certificate authority
TLSCACert []byte `json:"tls_ca_cert"`
// TLSCert is client TLS X509 certificate
TLSCert []byte `json:"tls_cert"`
SSHHostCert []byte `json:"cert"`
// TLSHostCACert is the certificate of the trusted host certificate authority
TLSHostCACert []byte `json:"tls_host_ca_cert"`
// TLSCert is client TLS host X509 certificate
TLSHostCert []byte `json:"tls_host_cert"`
// TLSUserCACert is the certificate of the trusted user certificate authority
TLSUserCACert []byte `json:"tls_user_ca_cert"`
// TLSUserCert is client TLS user X509 certificate
TLSUserCert []byte `json:"tls_user_cert"`
// TunnelAddr is a reverse tunnel listening port, allowing
// other sites to connect to i instance. Set to empty
// string if i instance is not allowing incoming tunnels
Expand Down Expand Up @@ -132,9 +136,7 @@ func (s *InstanceSecrets) GetRoles(t *testing.T) []types.Role {
return roles
}

// GetCAs return an array of CAs stored by the secrets object. In i
// case we always return hard-coded userCA + hostCA (and they share keys
// for simplicity)
// GetCAs return an array of CAs stored by the secrets object
func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
hostCA, err := types.NewCertAuthority(types.CertAuthoritySpecV2{
Type: types.HostCA,
Expand All @@ -148,7 +150,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand All @@ -168,7 +170,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSUserCACert,
}},
},
Roles: []string{services.RoleNameForCertAuthority(s.SiteName)},
Expand All @@ -184,7 +186,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand All @@ -199,7 +201,7 @@ func (s *InstanceSecrets) GetCAs() ([]types.CertAuthority, error) {
TLS: []*types.TLSKeyPair{{
Key: s.PrivKey,
KeyType: types.PrivateKeyType_RAW,
Cert: s.TLSCACert,
Cert: s.TLSHostCACert,
}},
},
})
Expand Down Expand Up @@ -256,9 +258,9 @@ func (s *InstanceSecrets) AsSlice() []*InstanceSecrets {

func (s *InstanceSecrets) GetIdentity() *state.Identity {
i, err := state.ReadIdentityFromKeyPair(s.PrivKey, &clientproto.Certs{
SSH: s.Cert,
TLS: s.TLSCert,
TLSCACerts: [][]byte{s.TLSCACert},
SSH: s.SSHHostCert,
TLS: s.TLSHostCert,
TLSCACerts: [][]byte{s.TLSHostCACert},
})
fatalIf(err)
return i
Expand Down Expand Up @@ -338,20 +340,14 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
if cfg.Priv == nil || cfg.Pub == nil {
cfg.Priv, cfg.Pub, _ = keygen.GenerateKeyPair()
}
rsaKey, err := ssh.ParseRawPrivateKey(cfg.Priv)
key, err := keys.ParsePrivateKey(cfg.Priv)
fatalIf(err)

tlsCACert, err := tlsca.GenerateSelfSignedCAWithSigner(rsaKey.(*rsa.PrivateKey), pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)

signer, err := ssh.ParsePrivateKey(cfg.Priv)
sshSigner, err := ssh.NewSignerFromSigner(key)
fatalIf(err)

cert, err := keygen.GenerateHostCert(services.HostCertParams{
CASigner: signer,
hostCert, err := keygen.GenerateHostCert(services.HostCertParams{
CASigner: sshSigner,
PublicHostKey: cfg.Pub,
HostID: cfg.HostID,
NodeName: cfg.NodeName,
Expand All @@ -360,23 +356,48 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
TTL: 24 * time.Hour,
})
fatalIf(err)
tlsCA, err := tlsca.FromKeys(tlsCACert, cfg.Priv)
fatalIf(err)
cryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
identity := tlsca.Identity{
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
Groups: []string{string(types.RoleAdmin)},
}

clock := cfg.Clock
if clock == nil {
clock = clockwork.NewRealClock()
}

identity := tlsca.Identity{
Username: fmt.Sprintf("%v.%v", cfg.HostID, cfg.ClusterName),
Groups: []string{string(types.RoleAdmin)},
}
subject, err := identity.Subject()
fatalIf(err)
tlsCert, err := tlsCA.GenerateCertificate(tlsca.CertificateRequest{

tlsCAHostCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)
tlsHostCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
fatalIf(err)
hostCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
tlsHostCert, err := tlsHostCA.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
PublicKey: hostCryptoPubKey,
Subject: subject,
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
})
fatalIf(err)

tlsCAUserCert, err := tlsca.GenerateSelfSignedCAWithSigner(key, pkix.Name{
CommonName: cfg.ClusterName,
Organization: []string{cfg.ClusterName},
}, nil, defaults.CATTL)
fatalIf(err)
tlsUserCA, err := tlsca.FromKeys(tlsCAHostCert, cfg.Priv)
fatalIf(err)
userCryptoPubKey, err := sshutils.CryptoPublicKey(cfg.Pub)
fatalIf(err)
tlsUserCert, err := tlsUserCA.GenerateCertificate(tlsca.CertificateRequest{
Clock: clock,
PublicKey: cryptoPubKey,
PublicKey: userCryptoPubKey,
Subject: subject,
NotAfter: clock.Now().UTC().Add(time.Hour * 24),
})
Expand All @@ -391,14 +412,16 @@ func NewInstance(t *testing.T, cfg InstanceConfig) *TeleInstance {
}

secrets := InstanceSecrets{
SiteName: cfg.ClusterName,
PrivKey: cfg.Priv,
PubKey: cfg.Pub,
Cert: cert,
TLSCACert: tlsCACert,
TLSCert: tlsCert,
TunnelAddr: i.ReverseTunnel,
Users: make(map[string]*User),
SiteName: cfg.ClusterName,
PrivKey: cfg.Priv,
PubKey: cfg.Pub,
SSHHostCert: hostCert,
TLSHostCACert: tlsCAHostCert,
TLSHostCert: tlsHostCert,
TLSUserCACert: tlsCAUserCert,
TLSUserCert: tlsUserCert,
TunnelAddr: i.ReverseTunnel,
Users: make(map[string]*User),
}

i.Secrets = secrets
Expand Down
108 changes: 81 additions & 27 deletions lib/auth/authclient/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,59 +39,113 @@ type CAGetter interface {
GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool) ([]types.CertAuthority, error)
}

// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caTypes.
// HostAndUserCAInfo is a map of CA raw subjects and type info for Host
// and User CAs. The key is the RawSubject of the X.509 certificate authority
// (so it's ASN.1 data, not printable).
type HostAndUserCAInfo = map[string]CATypeInfo

// CATypeInfo indicates whether the CA is a host or user CA, or both.
type CATypeInfo struct {
IsHostCA bool
IsUserCA bool
}

// ClientCertPool returns trusted x509 certificate authority pool with CAs provided as caType.
// In addition, it returns the total length of all subjects added to the cert pool, allowing
// the caller to validate that the pool doesn't exceed the maximum 2-byte length prefix before
// using it.
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) (*x509.CertPool, int64, error) {
if len(caTypes) == 0 {
return nil, 0, trace.BadParameter("at least one CA type is required")
func ClientCertPool(ctx context.Context, client CAGetter, clusterName string, caType types.CertAuthType) (*x509.CertPool, int64, error) {
authorities, err := getCACerts(ctx, client, clusterName, caType)
if err != nil {
return nil, 0, trace.Wrap(err)
}

pool := x509.NewCertPool()
var authorities []types.CertAuthority
if clusterName == "" {
for _, caType := range caTypes {
cas, err := client.GetCertAuthorities(ctx, caType, false)
if err != nil {
return nil, 0, trace.Wrap(err)
}
authorities = append(authorities, cas...)
}
} else {
for _, caType := range caTypes {
ca, err := client.GetCertAuthority(
ctx,
types.CertAuthID{Type: caType, DomainName: clusterName},
false)
var totalSubjectsLen int64
for _, auth := range authorities {
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, 0, trace.Wrap(err)
}
pool.AddCert(cert)

authorities = append(authorities, ca)
// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(cert.RawSubject))
}
}
return pool, totalSubjectsLen, nil
}

// DefaultClientCertPool returns default trusted x509 certificate authority pool.
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, HostAndUserCAInfo, int64, error) {
authorities, err := getCACerts(ctx, client, clusterName, types.HostCA, types.UserCA)
if err != nil {
return nil, nil, 0, trace.Wrap(err)
}

pool := x509.NewCertPool()
caInfos := make(HostAndUserCAInfo, len(authorities))
var totalSubjectsLen int64
for _, auth := range authorities {
for _, keyPair := range auth.GetTrustedTLSKeyPairs() {
cert, err := tlsca.ParseCertificatePEM(keyPair.Cert)
if err != nil {
return nil, 0, trace.Wrap(err)
return nil, nil, 0, trace.Wrap(err)
}
pool.AddCert(cert)

caType := auth.GetType()
caInfo := caInfos[string(cert.RawSubject)]
switch caType {
case types.HostCA:
caInfo.IsHostCA = true
case types.UserCA:
caInfo.IsUserCA = true
default:
return nil, nil, 0, trace.BadParameter("unexpected CA type %q", caType)
}
caInfos[string(cert.RawSubject)] = caInfo

// Each subject in the list gets a separate 2-byte length prefix.
totalSubjectsLen += 2
totalSubjectsLen += int64(len(cert.RawSubject))
}
}
return pool, totalSubjectsLen, nil

return pool, caInfos, totalSubjectsLen, nil
}

// DefaultClientCertPool returns default trusted x509 certificate authority pool.
func DefaultClientCertPool(ctx context.Context, client CAGetter, clusterName string) (*x509.CertPool, int64, error) {
return ClientCertPool(ctx, client, clusterName, types.HostCA, types.UserCA)
func getCACerts(ctx context.Context, client CAGetter, clusterName string, caTypes ...types.CertAuthType) ([]types.CertAuthority, error) {
if len(caTypes) == 0 {
return nil, trace.BadParameter("at least one CA type is required")
}

var authorities []types.CertAuthority
if clusterName == "" {
for _, caType := range caTypes {
cas, err := client.GetCertAuthorities(ctx, caType, false)
if err != nil {
return nil, trace.Wrap(err)
}
authorities = append(authorities, cas...)
}
} else {
for _, caType := range caTypes {
ca, err := client.GetCertAuthority(
ctx,
types.CertAuthID{Type: caType, DomainName: clusterName},
false)
if err != nil {
return nil, trace.Wrap(err)
}

authorities = append(authorities, ca)
}
}

return authorities, nil
}

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
Expand All @@ -110,7 +164,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
}
}
}
pool, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
pool, _, totalSubjectsLen, err := DefaultClientCertPool(info.Context(), ap, clusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", clusterName)
// this falls back to the default config
Expand All @@ -132,7 +186,7 @@ func WithClusterCAs(tlsConfig *tls.Config, ap CAGetter, currentClusterName strin
if totalSubjectsLen >= int64(math.MaxUint16) {
log.Debugf("Number of CAs in client cert pool is too large and cannot be encoded in a TLS handshake; this is due to a large number of trusted clusters; will use only the CA of the current cluster to validate.")

pool, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
pool, _, _, err = DefaultClientCertPool(info.Context(), ap, currentClusterName)
if err != nil {
log.WithError(err).Errorf("Failed to retrieve client pool for %q.", currentClusterName)
// this falls back to the default config
Expand Down
Loading