Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions api/utils/sshutils/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,23 +70,16 @@ func NewHostKeyCallback(conf HostKeyCallbackConfig) (ssh.HostKeyCallback, error)
return checker.CheckHostKey, nil
}

func makeIsHostAuthorityFunc(getCheckers CheckersGetter) func(key ssh.PublicKey, host string) bool {
return func(key ssh.PublicKey, host string) bool {
func makeIsHostAuthorityFunc(getCheckers CheckersGetter) func(authority ssh.PublicKey, host string) bool {
return func(authority ssh.PublicKey, host string) bool {
checkers, err := getCheckers()
if err != nil {
slog.ErrorContext(context.Background(), "Failed to get checkers.", "host", host, "error", err)
return false
}
for _, checker := range checkers {
switch v := key.(type) {
case *ssh.Certificate:
if KeysEqual(v.SignatureKey, checker) {
return true
}
default:
if KeysEqual(key, checker) {
return true
}
if KeysEqual(authority, checker) {
return true
}
}
slog.DebugContext(context.Background(), "No CA found for target host.", "host", host)
Expand Down
56 changes: 56 additions & 0 deletions api/utils/sshutils/callback_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2025 Gravitational, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sshutils

import (
"crypto/ed25519"
"crypto/rand"
"testing"

"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
)

func TestMakeIsHostAuthorityFunc(t *testing.T) {
rawCA1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca1, err := ssh.NewPublicKey(rawCA1)
require.NoError(t, err)

rawCA2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca2, err := ssh.NewPublicKey(rawCA2)
require.NoError(t, err)

rawCA3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca3, err := ssh.NewPublicKey(rawCA3)
require.NoError(t, err)

isHostAuthority := makeIsHostAuthorityFunc(func() ([]ssh.PublicKey, error) {
return []ssh.PublicKey{ca1, ca2}, nil
})

cert1 := &ssh.Certificate{
Key: ca1,
SignatureKey: ca1,
}

require.True(t, isHostAuthority(ca1, ""))
require.True(t, isHostAuthority(ca2, ""))
require.False(t, isHostAuthority(ca3, ""))

require.False(t, isHostAuthority(cert1, ""), "a certificate signed by a certificate should not pass validation")
}
20 changes: 10 additions & 10 deletions lib/client/keyagent.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ func (a *LocalKeyAgent) HostKeyCallback(addr string, remote net.Addr, hostKey ss

certChecker := sshutils.CertChecker{
CertChecker: ssh.CertChecker{
IsHostAuthority: a.checkHostCertificateForClusters(clusters...),
IsHostAuthority: a.isHostAuthorityForClusters(clusters...),
HostKeyFallback: a.checkHostKey,
},
FIPS: isFIPS(),
Expand All @@ -372,11 +372,11 @@ func (a *LocalKeyAgent) HostKeyCallback(addr string, remote net.Addr, hostKey ss
return nil
}

// checkHostCertificateForClusters validates a host certificate and check if remote key matches the know
// trusted cluster key based on ~/.tsh/known_hosts. If server key is not known, the users is prompted to accept or
// reject the server key.
func (a *LocalKeyAgent) checkHostCertificateForClusters(clusters ...string) func(key ssh.PublicKey, addr string) bool {
return func(key ssh.PublicKey, addr string) bool {
// isHostAuthorityForClusters validates a host certificate's issuer to see if it
// matches the known trusted cluster CA keys in ~/.tsh/known_hosts. If the CA is
// not known, the users is prompted to accept or reject it.
func (a *LocalKeyAgent) isHostAuthorityForClusters(clusters ...string) func(authority ssh.PublicKey, addr string) bool {
return func(authority ssh.PublicKey, addr string) bool {
// Check the local cache (where all Teleport CAs are placed upon login) to
// see if any of them match.
var keys []ssh.PublicKey
Expand All @@ -402,14 +402,14 @@ func (a *LocalKeyAgent) checkHostCertificateForClusters(clusters ...string) func
}

for i := range keys {
if sshutils.KeysEqual(key, keys[i]) {
if sshutils.KeysEqual(authority, keys[i]) {
return true
}
}

// If this certificate was not seen before, prompt the user essentially
// treating it like a key.
err = a.checkHostKey(addr, nil, key)
// If this CA was not seen before, prompt the user essentially treating
// it like a key.
err = a.checkHostKey(addr, nil, authority)
return err == nil
}
}
Expand Down
2 changes: 1 addition & 1 deletion lib/devicetrust/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ func NewSelfSignedSSHCert() ([]byte, crypto.Signer, error) {
if err != nil {
return nil, nil, trace.Wrap(err)
}
sshCert := &ssh.Certificate{Key: sshSigner.PublicKey(), SignatureKey: sshSigner.PublicKey(), Serial: 1, CertType: ssh.UserCert}
sshCert := &ssh.Certificate{Key: sshSigner.PublicKey(), Serial: 1, CertType: ssh.UserCert}
if err := sshCert.SignCert(rand.Reader, sshSigner); err != nil {
return nil, nil, trace.Wrap(err)
}
Expand Down
42 changes: 15 additions & 27 deletions lib/srv/authhandlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,21 +571,21 @@ func (h *AuthHandlers) hostKeyCallback(hostname string, remote net.Addr, key ssh
return nil
}

// IsUserAuthority is called during checking the client key, to see if the
// key used to sign the certificate was a Teleport CA.
func (h *AuthHandlers) IsUserAuthority(cert ssh.PublicKey) bool {
if _, err := h.authorityForCert(types.UserCA, cert); err != nil {
// IsUserAuthority is called during checking the issuer of a client certificate,
// to see if it was a Teleport CA.
func (h *AuthHandlers) IsUserAuthority(authority ssh.PublicKey) bool {
if _, err := h.authorityForCert(types.UserCA, authority); err != nil {
return false
}

return true
}

// IsHostAuthority is called when checking the host certificate a server
// presents. It make sure that the key used to sign the host certificate was a
// Teleport CA.
func (h *AuthHandlers) IsHostAuthority(cert ssh.PublicKey, address string) bool {
if _, err := h.authorityForCert(types.HostCA, cert); err != nil {
// IsHostAuthority is called when checking the issuer of a host certificate a
// server presents. It make sure that the key used to sign the host certificate
// was a Teleport CA.
func (h *AuthHandlers) IsHostAuthority(authority ssh.PublicKey, address string) bool {
if _, err := h.authorityForCert(types.HostCA, authority); err != nil {
h.log.Debugf("Unable to find SSH host CA: %v.", err)
return false
}
Expand Down Expand Up @@ -673,9 +673,9 @@ func fetchAccessInfo(ident *sshca.Identity, ca types.CertAuthority, clusterName
return accessInfo, trace.Wrap(err)
}

// authorityForCert checks if the certificate was signed by a Teleport
// Certificate Authority and returns it.
func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, key ssh.PublicKey) (types.CertAuthority, error) {
// authorityForCert searches for the Teleport Certificate Authority that
// contains the issuer of a certificate and returns it.
func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, authority ssh.PublicKey) (types.CertAuthority, error) {
// get all certificate authorities for given type
cas, err := h.c.AccessPoint.GetCertAuthorities(context.TODO(), caType, false)
if err != nil {
Expand All @@ -692,21 +692,9 @@ func (h *AuthHandlers) authorityForCert(caType types.CertAuthType, key ssh.Publi
return nil, trace.Wrap(err)
}
for _, checker := range checkers {
// if we have a certificate, compare the certificate signing key against
// the ca key. otherwise check the public key that was passed in. this is
// due to the differences in how this function is called by the user and
// host checkers.
switch v := key.(type) {
case *ssh.Certificate:
if apisshutils.KeysEqual(v.SignatureKey, checker) {
ca = cas[i]
break
}
default:
if apisshutils.KeysEqual(key, checker) {
ca = cas[i]
break
}
if apisshutils.KeysEqual(authority, checker) {
ca = cas[i]
break
}
}
}
Expand Down
52 changes: 52 additions & 0 deletions lib/srv/authhandlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ package srv

import (
"context"
"crypto/ed25519"
"crypto/rand"
"net"
"testing"

Expand Down Expand Up @@ -542,3 +544,53 @@ func TestRBACJoinMFA(t *testing.T) {
})
}
}

func TestAuthorityForCert(t *testing.T) {
rawCA1, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca1, err := ssh.NewPublicKey(rawCA1)
require.NoError(t, err)

rawCA2, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca2, err := ssh.NewPublicKey(rawCA2)
require.NoError(t, err)

rawCA3, _, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err)
ca3, err := ssh.NewPublicKey(rawCA3)
require.NoError(t, err)

ah := &AuthHandlers{c: &AuthHandlerConfig{
Server: (*mockServer)(nil),
AccessPoint: mockCAandAuthPrefGetter{
cas: map[types.CertAuthType][]types.CertAuthority{
types.UserCA: {&types.CertAuthorityV2{
Spec: types.CertAuthoritySpecV2{
ActiveKeys: types.CAKeySet{
SSH: []*types.SSHKeyPair{
{PublicKey: ssh.MarshalAuthorizedKey(ca1)},
{PublicKey: ssh.MarshalAuthorizedKey(ca2)},
},
},
},
}},
},
},
}}

cert1 := &ssh.Certificate{
Key: ca1,
SignatureKey: ca1,
}

_, err = ah.authorityForCert(types.UserCA, ca1)
require.NoError(t, err)
_, err = ah.authorityForCert(types.UserCA, ca2)
require.NoError(t, err)
_, err = ah.authorityForCert(types.UserCA, ca3)
require.ErrorAs(t, err, new(*trace.AccessDeniedError))

_, err = ah.authorityForCert(types.UserCA, cert1)
require.ErrorAs(t, err, new(*trace.AccessDeniedError), "a certificate signed by a certificate should not pass validation")
}
7 changes: 1 addition & 6 deletions lib/utils/cert/certs.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,7 @@ func createCertificate(principal string, certType uint32, algo cryptosuites.Algo
if err != nil {
return nil, nil, trace.Wrap(err)
}
caPublicKey, err := ssh.NewPublicKey(caKey.Public())
if err != nil {
return nil, nil, trace.Wrap(err)
}
caSigner, err := ssh.NewSignerFromKey(caKey)
caSigner, err := ssh.NewSignerFromSigner(caKey)
if err != nil {
return nil, nil, trace.Wrap(err)
}
Expand All @@ -81,7 +77,6 @@ func createCertificate(principal string, certType uint32, algo cryptosuites.Algo
KeyId: principal,
ValidPrincipals: []string{principal},
Key: publicKey,
SignatureKey: caPublicKey,
ValidAfter: uint64(time.Now().UTC().Add(-1 * time.Minute).Unix()),
ValidBefore: uint64(time.Now().UTC().Add(1 * time.Minute).Unix()),
CertType: certType,
Expand Down
Loading