From d18df75fbe2d6fbc176dffb435416381cd76fc66 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Fri, 6 Jun 2025 21:55:39 +0200 Subject: [PATCH] Avoid counterproductive checks in IsUserAuthority and IsHostAuthority --- api/utils/sshutils/callback.go | 15 +++----- api/utils/sshutils/callback_test.go | 56 +++++++++++++++++++++++++++++ lib/client/keyagent.go | 20 +++++------ lib/devicetrust/testenv/testenv.go | 2 +- lib/srv/authhandlers.go | 42 ++++++++-------------- lib/srv/authhandlers_test.go | 52 +++++++++++++++++++++++++++ lib/utils/cert/certs.go | 7 +--- 7 files changed, 139 insertions(+), 55 deletions(-) create mode 100644 api/utils/sshutils/callback_test.go diff --git a/api/utils/sshutils/callback.go b/api/utils/sshutils/callback.go index c96e4ff39dc8d..a69931f28d870 100644 --- a/api/utils/sshutils/callback.go +++ b/api/utils/sshutils/callback.go @@ -70,23 +70,16 @@ func NewHostKeyCallback(conf HostKeyCallbackConfig) (ssh.HostKeyCallback, error) return checker.CheckHostKey, nil } -func makeIsHostAuthorityFunc(getCheckers CheckersGetter) func(key ssh.PublicKey, host string) bool { - return func(key ssh.PublicKey, host string) bool { +func makeIsHostAuthorityFunc(getCheckers CheckersGetter) func(authority ssh.PublicKey, host string) bool { + return func(authority ssh.PublicKey, host string) bool { checkers, err := getCheckers() if err != nil { slog.ErrorContext(context.Background(), "Failed to get checkers.", "host", host, "error", err) return false } for _, checker := range checkers { - switch v := key.(type) { - case *ssh.Certificate: - if KeysEqual(v.SignatureKey, checker) { - return true - } - default: - if KeysEqual(key, checker) { - return true - } + if KeysEqual(authority, checker) { + return true } } slog.DebugContext(context.Background(), "No CA found for target host.", "host", host) diff --git a/api/utils/sshutils/callback_test.go b/api/utils/sshutils/callback_test.go new file mode 100644 index 0000000000000..78ab289e33b1e --- /dev/null +++ b/api/utils/sshutils/callback_test.go @@ -0,0 +1,56 @@ +// Copyright 2025 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sshutils + +import ( + "crypto/ed25519" + "crypto/rand" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/crypto/ssh" +) + +func TestMakeIsHostAuthorityFunc(t *testing.T) { + rawCA1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca1, err := ssh.NewPublicKey(rawCA1) + require.NoError(t, err) + + rawCA2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca2, err := ssh.NewPublicKey(rawCA2) + require.NoError(t, err) + + rawCA3, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca3, err := ssh.NewPublicKey(rawCA3) + require.NoError(t, err) + + isHostAuthority := makeIsHostAuthorityFunc(func() ([]ssh.PublicKey, error) { + return []ssh.PublicKey{ca1, ca2}, nil + }) + + cert1 := &ssh.Certificate{ + Key: ca1, + SignatureKey: ca1, + } + + require.True(t, isHostAuthority(ca1, "")) + require.True(t, isHostAuthority(ca2, "")) + require.False(t, isHostAuthority(ca3, "")) + + require.False(t, isHostAuthority(cert1, ""), "a certificate signed by a certificate should not pass validation") +} diff --git a/lib/client/keyagent.go b/lib/client/keyagent.go index 65a4f3fc82c93..d34478d9a038f 100644 --- a/lib/client/keyagent.go +++ b/lib/client/keyagent.go @@ -357,7 +357,7 @@ func (a *LocalKeyAgent) HostKeyCallback(addr string, remote net.Addr, hostKey ss certChecker := sshutils.CertChecker{ CertChecker: ssh.CertChecker{ - IsHostAuthority: a.checkHostCertificateForClusters(clusters...), + IsHostAuthority: a.isHostAuthorityForClusters(clusters...), HostKeyFallback: a.checkHostKey, }, FIPS: isFIPS(), @@ -372,11 +372,11 @@ func (a *LocalKeyAgent) HostKeyCallback(addr string, remote net.Addr, hostKey ss return nil } -// checkHostCertificateForClusters validates a host certificate and check if remote key matches the know -// trusted cluster key based on ~/.tsh/known_hosts. If server key is not known, the users is prompted to accept or -// reject the server key. -func (a *LocalKeyAgent) checkHostCertificateForClusters(clusters ...string) func(key ssh.PublicKey, addr string) bool { - return func(key ssh.PublicKey, addr string) bool { +// isHostAuthorityForClusters validates a host certificate's issuer to see if it +// matches the known trusted cluster CA keys in ~/.tsh/known_hosts. If the CA is +// not known, the users is prompted to accept or reject it. +func (a *LocalKeyAgent) isHostAuthorityForClusters(clusters ...string) func(authority ssh.PublicKey, addr string) bool { + return func(authority ssh.PublicKey, addr string) bool { // Check the local cache (where all Teleport CAs are placed upon login) to // see if any of them match. var keys []ssh.PublicKey @@ -402,14 +402,14 @@ func (a *LocalKeyAgent) checkHostCertificateForClusters(clusters ...string) func } for i := range keys { - if sshutils.KeysEqual(key, keys[i]) { + if sshutils.KeysEqual(authority, keys[i]) { return true } } - // If this certificate was not seen before, prompt the user essentially - // treating it like a key. - err = a.checkHostKey(addr, nil, key) + // If this CA was not seen before, prompt the user essentially treating + // it like a key. + err = a.checkHostKey(addr, nil, authority) return err == nil } } diff --git a/lib/devicetrust/testenv/testenv.go b/lib/devicetrust/testenv/testenv.go index a1ee369e628c6..f342736340d06 100644 --- a/lib/devicetrust/testenv/testenv.go +++ b/lib/devicetrust/testenv/testenv.go @@ -223,7 +223,7 @@ func NewSelfSignedSSHCert() ([]byte, crypto.Signer, error) { if err != nil { return nil, nil, trace.Wrap(err) } - sshCert := &ssh.Certificate{Key: sshSigner.PublicKey(), SignatureKey: sshSigner.PublicKey(), Serial: 1, CertType: ssh.UserCert} + sshCert := &ssh.Certificate{Key: sshSigner.PublicKey(), Serial: 1, CertType: ssh.UserCert} if err := sshCert.SignCert(rand.Reader, sshSigner); err != nil { return nil, nil, trace.Wrap(err) } diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 8c4321299744d..3bd06ae98541e 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -571,21 +571,21 @@ func (h *AuthHandlers) hostKeyCallback(hostname string, remote net.Addr, key ssh return nil } -// IsUserAuthority is called during checking the client key, to see if the -// key used to sign the certificate was a Teleport CA. -func (h *AuthHandlers) IsUserAuthority(cert ssh.PublicKey) bool { - if _, err := h.authorityForCert(types.UserCA, cert); err != nil { +// IsUserAuthority is called during checking the issuer of a client certificate, +// to see if it was a Teleport CA. +func (h *AuthHandlers) IsUserAuthority(authority ssh.PublicKey) bool { + if _, err := h.authorityForCert(types.UserCA, authority); err != nil { return false } return true } -// IsHostAuthority is called when checking the host certificate a server -// presents. It make sure that the key used to sign the host certificate was a -// Teleport CA. -func (h *AuthHandlers) IsHostAuthority(cert ssh.PublicKey, address string) bool { - if _, err := h.authorityForCert(types.HostCA, cert); err != nil { +// IsHostAuthority is called when checking the issuer of a host certificate a +// server presents. It make sure that the key used to sign the host certificate +// was a Teleport CA. +func (h *AuthHandlers) IsHostAuthority(authority ssh.PublicKey, address string) bool { + if _, err := h.authorityForCert(types.HostCA, authority); err != nil { h.log.Debugf("Unable to find SSH host CA: %v.", err) return false } @@ -673,9 +673,9 @@ func fetchAccessInfo(ident *sshca.Identity, ca types.CertAuthority, clusterName return accessInfo, trace.Wrap(err) } -// authorityForCert checks if the certificate was signed by a Teleport -// Certificate Authority and returns it. -func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, key ssh.PublicKey) (types.CertAuthority, error) { +// authorityForCert searches for the Teleport Certificate Authority that +// contains the issuer of a certificate and returns it. +func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, authority ssh.PublicKey) (types.CertAuthority, error) { // get all certificate authorities for given type cas, err := h.c.AccessPoint.GetCertAuthorities(context.TODO(), caType, false) if err != nil { @@ -692,21 +692,9 @@ func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, key ssh.Publi return nil, trace.Wrap(err) } for _, checker := range checkers { - // if we have a certificate, compare the certificate signing key against - // the ca key. otherwise check the public key that was passed in. this is - // due to the differences in how this function is called by the user and - // host checkers. - switch v := key.(type) { - case *ssh.Certificate: - if apisshutils.KeysEqual(v.SignatureKey, checker) { - ca = cas[i] - break - } - default: - if apisshutils.KeysEqual(key, checker) { - ca = cas[i] - break - } + if apisshutils.KeysEqual(authority, checker) { + ca = cas[i] + break } } } diff --git a/lib/srv/authhandlers_test.go b/lib/srv/authhandlers_test.go index 58f5bc5d53d0b..646b2dcd845c3 100644 --- a/lib/srv/authhandlers_test.go +++ b/lib/srv/authhandlers_test.go @@ -20,6 +20,8 @@ package srv import ( "context" + "crypto/ed25519" + "crypto/rand" "net" "testing" @@ -542,3 +544,53 @@ func TestRBACJoinMFA(t *testing.T) { }) } } + +func TestAuthorityForCert(t *testing.T) { + rawCA1, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca1, err := ssh.NewPublicKey(rawCA1) + require.NoError(t, err) + + rawCA2, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca2, err := ssh.NewPublicKey(rawCA2) + require.NoError(t, err) + + rawCA3, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + ca3, err := ssh.NewPublicKey(rawCA3) + require.NoError(t, err) + + ah := &AuthHandlers{c: &AuthHandlerConfig{ + Server: (*mockServer)(nil), + AccessPoint: mockCAandAuthPrefGetter{ + cas: map[types.CertAuthType][]types.CertAuthority{ + types.UserCA: {&types.CertAuthorityV2{ + Spec: types.CertAuthoritySpecV2{ + ActiveKeys: types.CAKeySet{ + SSH: []*types.SSHKeyPair{ + {PublicKey: ssh.MarshalAuthorizedKey(ca1)}, + {PublicKey: ssh.MarshalAuthorizedKey(ca2)}, + }, + }, + }, + }}, + }, + }, + }} + + cert1 := &ssh.Certificate{ + Key: ca1, + SignatureKey: ca1, + } + + _, err = ah.authorityForCert(types.UserCA, ca1) + require.NoError(t, err) + _, err = ah.authorityForCert(types.UserCA, ca2) + require.NoError(t, err) + _, err = ah.authorityForCert(types.UserCA, ca3) + require.ErrorAs(t, err, new(*trace.AccessDeniedError)) + + _, err = ah.authorityForCert(types.UserCA, cert1) + require.ErrorAs(t, err, new(*trace.AccessDeniedError), "a certificate signed by a certificate should not pass validation") +} diff --git a/lib/utils/cert/certs.go b/lib/utils/cert/certs.go index 47de43a21a42f..da05058eb5a1e 100644 --- a/lib/utils/cert/certs.go +++ b/lib/utils/cert/certs.go @@ -53,11 +53,7 @@ func createCertificate(principal string, certType uint32, algo cryptosuites.Algo if err != nil { return nil, nil, trace.Wrap(err) } - caPublicKey, err := ssh.NewPublicKey(caKey.Public()) - if err != nil { - return nil, nil, trace.Wrap(err) - } - caSigner, err := ssh.NewSignerFromKey(caKey) + caSigner, err := ssh.NewSignerFromSigner(caKey) if err != nil { return nil, nil, trace.Wrap(err) } @@ -81,7 +77,6 @@ func createCertificate(principal string, certType uint32, algo cryptosuites.Algo KeyId: principal, ValidPrincipals: []string{principal}, Key: publicKey, - SignatureKey: caPublicKey, ValidAfter: uint64(time.Now().UTC().Add(-1 * time.Minute).Unix()), ValidBefore: uint64(time.Now().UTC().Add(1 * time.Minute).Unix()), CertType: certType,