diff --git a/api/identityfile/identityfile.go b/api/identityfile/identityfile.go index 2ac4596ca8f2e..a1c600d0df6ac 100644 --- a/api/identityfile/identityfile.go +++ b/api/identityfile/identityfile.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "os" - "regexp" "strings" "github.com/gravitational/teleport/api/utils/keypaths" @@ -303,14 +302,8 @@ func decodeIdentityFile(idFile io.Reader) (*IdentityFile, error) { return &ident, nil } -// OpenSSH cert types look like "-cert-v@openssh.com". -// Currently, we only use "ssh-rsa-cert-v01@openssh.com" & "ecdsa-sha2-nistp256-cert-v01@openssh.com". -var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`) - // Check if the given data has an ssh cert type prefix as it's first part. func isSSHCert(data []byte) bool { - // ssh certs should look like " ", - // so we check if the first element matches a known ssh cert type. sshCertType := bytes.Split(data, []byte(" "))[0] - return sshCertTypeRegex.Match(sshCertType) + return sshutils.IsSSHCertType(string(sshCertType)) } diff --git a/api/utils/sshutils/ssh.go b/api/utils/sshutils/ssh.go index fd5578a742e21..68c0b8d147391 100644 --- a/api/utils/sshutils/ssh.go +++ b/api/utils/sshutils/ssh.go @@ -23,6 +23,7 @@ import ( "crypto/subtle" "io" "net" + "regexp" "github.com/gravitational/teleport/api/defaults" @@ -206,3 +207,12 @@ func KeysEqual(ak, bk ssh.PublicKey) bool { b := bk.Marshal() return subtle.ConstantTimeCompare(a, b) == 1 } + +// OpenSSH cert types look like "-cert-v@openssh.com". +var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`) + +// IsSSHCertType checks if the given string looks like an ssh cert type. +// e.g. rsa-sha2-256-cert-v01@openssh.com. +func IsSSHCertType(val string) bool { + return sshCertTypeRegex.MatchString(val) +} diff --git a/lib/client/api.go b/lib/client/api.go index 3a8a169ad54f3..1cf5861d19cdb 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3118,7 +3118,7 @@ func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, err } else if tc.localAgent != nil { // Load SSH certs for all clusters we have, in case we don't yet // have a certificate for tc.SiteName (like during `tsh login leaf`). - signers, err := tc.localAgent.signers() + signers, err := tc.localAgent.Signers() // errNoLocalKeyStore is returned when running in the proxy. The proxy // should be passing auth methods via tc.Config.AuthMethods. if err != nil && !errors.Is(err, errNoLocalKeyStore) && !trace.IsNotFound(err) { diff --git a/lib/client/keyagent.go b/lib/client/keyagent.go index fc55b8113035e..4284f5a56e46b 100644 --- a/lib/client/keyagent.go +++ b/lib/client/keyagent.go @@ -632,29 +632,50 @@ func (a *LocalKeyAgent) DeleteKeys() error { return nil } -// certsForCluster returns a set of ssh.Signers using all certificates +// Signers returns a set of ssh.Signers using all certificates // for the current proxy and user. -func (a *LocalKeyAgent) signers() ([]ssh.Signer, error) { - k, err := a.GetCoreKey() - if err != nil { - return nil, trace.Wrap(err) +func (a *LocalKeyAgent) Signers() ([]ssh.Signer, error) { + var signers []ssh.Signer + + // If we find a valid key store, load all valid ssh certificates as signers. + if k, err := a.GetCoreKey(); err == nil { + certs, err := a.keyStore.GetSSHCertificates(a.proxyHost, a.username) + if err != nil { + return nil, trace.Wrap(err) + } + + for _, cert := range certs { + if err := k.checkCert(cert); err != nil { + return nil, trace.Wrap(err) + } + signer, err := sshutils.SSHSigner(cert, k) + if err != nil { + return nil, trace.Wrap(err) + } + signers = append(signers, signer) + } } - certs, err := a.keyStore.GetSSHCertificates(a.proxyHost, a.username) + // Load all agent certs, including the ones from a local SSH agent. + agentSigners, err := a.ExtendedAgent.Signers() if err != nil { return nil, trace.Wrap(err) } - - signers := make([]ssh.Signer, len(certs)) - for i, cert := range certs { - if err := k.checkCert(cert); err != nil { - return nil, trace.Wrap(err) - } - signer, err := sshutils.SSHSigner(cert, k) + if a.sshAgent != nil { + sshAgentSigners, err := a.sshAgent.Signers() if err != nil { return nil, trace.Wrap(err) } - signers[i] = signer + agentSigners = append(signers, sshAgentSigners...) + } + + // Filter out non-certificates (like regular public SSH keys stored in the SSH agent). + for _, s := range agentSigners { + if _, ok := s.PublicKey().(*ssh.Certificate); ok { + signers = append(signers, s) + } else if k, ok := s.PublicKey().(*agent.Key); ok && sshutils.IsSSHCertType(k.Type()) { + signers = append(signers, s) + } } return signers, nil