Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions api/identityfile/identityfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"fmt"
"io"
"os"
"regexp"
"strings"

"github.com/gravitational/trace"
Expand Down Expand Up @@ -303,14 +302,8 @@ func decodeIdentityFile(idFile io.Reader) (*IdentityFile, error) {
return &ident, nil
}

// OpenSSH cert types look like "<key-type>-cert-v<version>@openssh.com".
// Currently, we only use "ssh-rsa-cert-v01@openssh.com" & "ecdsa-sha2-nistp256-cert-v01@openssh.com".
var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`)

// Check if the given data has an ssh cert type prefix as it's first part.
func isSSHCert(data []byte) bool {
// ssh certs should look like "<ssh-cert-type> <cert-data>",
// so we check if the first element matches a known ssh cert type.
sshCertType := bytes.Split(data, []byte(" "))[0]
return sshCertTypeRegex.Match(sshCertType)
return sshutils.IsSSHCertType(string(sshCertType))
}
10 changes: 10 additions & 0 deletions api/utils/sshutils/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"crypto/subtle"
"io"
"net"
"regexp"

"github.com/gravitational/trace"
"golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -206,3 +207,12 @@ func KeysEqual(ak, bk ssh.PublicKey) bool {
b := bk.Marshal()
return subtle.ConstantTimeCompare(a, b) == 1
}

// OpenSSH cert types look like "<key-type>-cert-v<version>@openssh.com".
var sshCertTypeRegex = regexp.MustCompile(`^[a-z0-9\-]+-cert-v[0-9]{2}@openssh\.com$`)

// IsSSHCertType checks if the given string looks like an ssh cert type.
// e.g. rsa-sha2-256-cert-v01@openssh.com.
func IsSSHCertType(val string) bool {
return sshCertTypeRegex.MatchString(val)
}
2 changes: 1 addition & 1 deletion lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3223,7 +3223,7 @@ func (tc *TeleportClient) connectToProxy(ctx context.Context) (*ProxyClient, err
} else if tc.localAgent != nil {
// Load SSH certs for all clusters we have, in case we don't yet
// have a certificate for tc.SiteName (like during `tsh login leaf`).
signers, err := tc.localAgent.signers()
signers, err := tc.localAgent.Signers()
// errNoLocalKeyStore is returned when running in the proxy. The proxy
// should be passing auth methods via tc.Config.AuthMethods.
if err != nil && !errors.Is(err, errNoLocalKeyStore) && !trace.IsNotFound(err) {
Expand Down
49 changes: 35 additions & 14 deletions lib/client/keyagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,29 +630,50 @@ func (a *LocalKeyAgent) DeleteKeys() error {
return nil
}

// certsForCluster returns a set of ssh.Signers using all certificates
// Signers returns a set of ssh.Signers using all certificates
// for the current proxy and user.
func (a *LocalKeyAgent) signers() ([]ssh.Signer, error) {
k, err := a.GetCoreKey()
if err != nil {
return nil, trace.Wrap(err)
func (a *LocalKeyAgent) Signers() ([]ssh.Signer, error) {
var signers []ssh.Signer

// If we find a valid key store, load all valid ssh certificates as signers.
if k, err := a.GetCoreKey(); err == nil {
certs, err := a.keyStore.GetSSHCertificates(a.proxyHost, a.username)
if err != nil {
return nil, trace.Wrap(err)
}

for _, cert := range certs {
if err := k.checkCert(cert); err != nil {
return nil, trace.Wrap(err)
}
signer, err := sshutils.SSHSigner(cert, k)
if err != nil {
return nil, trace.Wrap(err)
}
signers = append(signers, signer)
}
}

certs, err := a.keyStore.GetSSHCertificates(a.proxyHost, a.username)
// Load all agent certs, including the ones from a local SSH agent.
agentSigners, err := a.ExtendedAgent.Signers()
if err != nil {
return nil, trace.Wrap(err)
}

signers := make([]ssh.Signer, len(certs))
for i, cert := range certs {
if err := k.checkCert(cert); err != nil {
return nil, trace.Wrap(err)
}
signer, err := sshutils.SSHSigner(cert, k)
if a.sshAgent != nil {
sshAgentSigners, err := a.sshAgent.Signers()
if err != nil {
return nil, trace.Wrap(err)
}
signers[i] = signer
agentSigners = append(signers, sshAgentSigners...)
}

// Filter out non-certificates (like regular public SSH keys stored in the SSH agent).
for _, s := range agentSigners {
if _, ok := s.PublicKey().(*ssh.Certificate); ok {
signers = append(signers, s)
} else if k, ok := s.PublicKey().(*agent.Key); ok && sshutils.IsSSHCertType(k.Type()) {
signers = append(signers, s)
}
}

return signers, nil
Expand Down