diff --git a/ssh/certs.go b/ssh/certs.go index c7a4dd0ab0..a69e22491d 100644 --- a/ssh/certs.go +++ b/ssh/certs.go @@ -483,6 +483,17 @@ func underlyingAlgo(algo string) string { return algo } +// certificateAlgo returns the certificate algorithms that uses the provided +// underlying signature algorithm. +func certificateAlgo(algo string) (certAlgo string, ok bool) { + for certName, algoName := range certKeyAlgoNames { + if algoName == algo { + return certName, true + } + } + return "", false +} + func (cert *Certificate) bytesForSigning() []byte { c2 := *cert c2.Signature = nil @@ -526,13 +537,11 @@ func (c *Certificate) Marshal() []byte { // Type returns the certificate algorithm name. It is part of the PublicKey interface. func (c *Certificate) Type() string { - keyType := c.Key.Type() - for certName, keyName := range certKeyAlgoNames { - if keyName == keyType { - return certName - } + certName, ok := certificateAlgo(c.Key.Type()) + if !ok { + panic("unknown certificate type for key type " + c.Key.Type()) } - panic("unknown certificate type for key type " + keyType) + return certName } // Verify verifies a signature against the certificate's public diff --git a/ssh/client_auth.go b/ssh/client_auth.go index a962a679a5..409b5ea1d4 100644 --- a/ssh/client_auth.go +++ b/ssh/client_auth.go @@ -234,7 +234,17 @@ func pickSignatureAlgorithm(signer Signer, extensions map[string][]byte) (as Alg return as, keyFormat } + // The server-sig-algs extension only carries underlying signature + // algorithm, but we are trying to select a protocol-level public key + // algorithm, which might be a certificate type. Extend the list of server + // supported algorithms to include the corresponding certificate algorithms. serverAlgos := strings.Split(string(extPayload), ",") + for _, algo := range serverAlgos { + if certAlgo, ok := certificateAlgo(algo); ok { + serverAlgos = append(serverAlgos, certAlgo) + } + } + keyAlgos := algorithmsForKeyFormat(keyFormat) algo, err := findCommon("public key signature algorithm", keyAlgos, serverAlgos) if err != nil {