Skip to content
13 changes: 13 additions & 0 deletions api/profile/profile.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ import (
"github.com/gravitational/teleport/api/utils/keys"
"github.com/gravitational/teleport/api/utils/keys/hardwarekey"
"github.com/gravitational/teleport/api/utils/sshutils"
libdefaults "github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/utils"
)

const (
Expand Down Expand Up @@ -486,3 +488,14 @@ func (p *Profile) AppCertPath(appName string) string {
func (p *Profile) AppKeyPath(appName string) string {
return keypaths.AppKeyPath(p.Dir, p.Name(), p.Username, p.SiteName, appName)
}

// WebProxyHostPort returns the host and port of the web proxy.
func (p *Profile) WebProxyHostPort() (string, int) {
if p.WebProxyAddr != "" {
addr, err := utils.ParseAddr(p.WebProxyAddr)
if err == nil {
return addr.Host(), addr.Port(libdefaults.HTTPListenPort)
}
}
return "unknown", libdefaults.HTTPListenPort
}
4 changes: 4 additions & 0 deletions api/utils/keys/hardwarekey/hardwarekey.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ type Service interface {
// GetFullKeyRef gets the full [PrivateKeyRef] for an existing hardware private
// key in the given slot of the hardware key with the given serial number.
GetFullKeyRef(serialNumber uint32, slotKey PIVSlotKey) (*PrivateKeyRef, error)
// SetPrompt sets the hardware key prompt used by the hardware key service, if applicable.
// This is used by Teleport Connect which sets the prompt later than the hardware key service,
// due to process initialization constraints.
SetPrompt(prompt Prompt)
}

// Signer is a hardware key implementation of [crypto.Signer].
Expand Down
49 changes: 25 additions & 24 deletions api/utils/keys/piv/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,23 +35,16 @@ import (
"github.com/gravitational/teleport/api/utils/keys/hardwarekey"
)

// TODO(Joerger): Rather than using a global cache and mutexes, clients should be updated
// to create a single YubiKeyService and ensure it is reused across the program execution.
var (
// YubiKeyService is a YubiKey PIV implementation of [hardwarekey.Service].
type YubiKeyService struct {
prompt hardwarekey.Prompt
promptMux sync.Mutex

// yubiKeys is a shared, thread-safe [YubiKey] cache by serial number. It allows for
// separate goroutines to share a YubiKey connection to work around the single PC/SC
// transaction (connection) per-yubikey limit.
yubiKeys map[uint32]*YubiKey = map[uint32]*YubiKey{}
yubiKeys map[uint32]*YubiKey
yubiKeysMux sync.Mutex

// promptMux is used to prevent over-prompting, especially for back-to-back sign requests
// since touch/PIN from the first signature should be cached for following signatures.
promptMux sync.Mutex
)

// YubiKeyService is a YubiKey PIV implementation of [hardwarekey.Service].
type YubiKeyService struct {
prompt hardwarekey.Prompt
}

// Returns a new [YubiKeyService]. If [customPrompt] is nil, the default CLI prompt will be used.
Expand All @@ -64,7 +57,8 @@ func NewYubiKeyService(customPrompt hardwarekey.Prompt) *YubiKeyService {
}

return &YubiKeyService{
prompt: customPrompt,
prompt: customPrompt,
yubiKeys: map[uint32]*YubiKey{},
}
}

Expand Down Expand Up @@ -170,8 +164,8 @@ func (s *YubiKeyService) Sign(ctx context.Context, ref *hardwarekey.PrivateKeyRe
return nil, trace.Wrap(err)
}

promptMux.Lock()
defer promptMux.Unlock()
s.promptMux.Lock()
defer s.promptMux.Unlock()

return y.sign(ctx, ref, keyInfo, s.prompt, rand, digest, opts)
}
Expand Down Expand Up @@ -224,13 +218,20 @@ func (s *YubiKeyService) GetFullKeyRef(serialNumber uint32, slotKey hardwarekey.
return ref, nil
}

// SetPrompt sets the hardware key prompt.
func (s *YubiKeyService) SetPrompt(prompt hardwarekey.Prompt) {
s.promptMux.Lock()
defer s.promptMux.Unlock()
s.prompt = prompt
}

// Get the given YubiKey with the serial number. If the provided serialNumber is "0",
// return the first YubiKey found in the smart card list.
func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) {
yubiKeysMux.Lock()
defer yubiKeysMux.Unlock()
s.yubiKeysMux.Lock()
defer s.yubiKeysMux.Unlock()

if y, ok := yubiKeys[serialNumber]; ok {
if y, ok := s.yubiKeys[serialNumber]; ok {
return y, nil
}

Expand All @@ -239,16 +240,16 @@ func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) {
return nil, trace.Wrap(err)
}

yubiKeys[y.serialNumber] = y
s.yubiKeys[y.serialNumber] = y
return y, nil
}

// checkOrSetPIN prompts the user for PIN and verifies it with the YubiKey.
// If the user provides the default PIN, they will be prompted to set a
// non-default PIN and PUK before continuing.
func (s *YubiKeyService) checkOrSetPIN(ctx context.Context, y *YubiKey, keyInfo hardwarekey.ContextualKeyInfo) error {
promptMux.Lock()
defer promptMux.Unlock()
s.promptMux.Lock()
defer s.promptMux.Unlock()

pin, err := s.prompt.AskPIN(ctx, hardwarekey.PINOptional, keyInfo)
if err != nil {
Expand All @@ -270,8 +271,8 @@ func (s *YubiKeyService) checkOrSetPIN(ctx context.Context, y *YubiKey, keyInfo
}

func (s *YubiKeyService) promptOverwriteSlot(ctx context.Context, msg string, keyInfo hardwarekey.ContextualKeyInfo) error {
promptMux.Lock()
defer promptMux.Unlock()
s.promptMux.Lock()
defer s.promptMux.Unlock()

promptQuestion := fmt.Sprintf("%v\nWould you like to overwrite this slot's private key and certificate?", msg)
if confirmed, confirmErr := s.prompt.ConfirmSlotOverwrite(ctx, promptQuestion, keyInfo); confirmErr != nil {
Expand Down
11 changes: 2 additions & 9 deletions api/utils/keys/piv/service_fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,8 @@ import (
"github.com/gravitational/teleport/api/utils/keys/hardwarekey"
)

// TODO(Joerger): Rather than using a global service, clients should be updated to
// create a single YubiKeyService and ensure it is reused across the program
// execution. At this point, it may make more sense to directly inject the mocked
// hardware key service into the test instead of using the pivtest build tag to do it.
var mockedHardwareKeyService = hardwarekey.NewMockHardwareKeyService(nil /*prompt*/)

// Returns a globally shared [hardwarekey.MockHardwareKeyService]. Test callers should
// Returns a new [hardwarekey.MockHardwareKeyService]. Test callers should
// prefer [hardwarekey.NewMockHardwareKeyService] when possible.
func NewYubiKeyService(prompt hardwarekey.Prompt) *hardwarekey.MockHardwareKeyService {
mockedHardwareKeyService.SetPrompt(prompt)
return mockedHardwareKeyService
return hardwarekey.NewMockHardwareKeyService(prompt)
}
2 changes: 2 additions & 0 deletions api/utils/keys/piv/service_unavailable.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,5 @@ func (s *unavailableYubiKeyPIVService) Sign(_ context.Context, _ *hardwarekey.Pr
func (s *unavailableYubiKeyPIVService) GetFullKeyRef(serialNumber uint32, slotKey hardwarekey.PIVSlotKey) (*hardwarekey.PrivateKeyRef, error) {
return nil, trace.Wrap(errPIVUnavailable)
}

func (s *unavailableYubiKeyPIVService) SetPrompt(_ hardwarekey.Prompt) {}
4 changes: 2 additions & 2 deletions integration/helpers/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -1468,7 +1468,7 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te
HostPort: cfg.Port,
HostLogin: cfg.Login,
InsecureSkipVerify: true,
KeysDir: keyDir,
ClientStore: client.NewFSClientStore(keyDir),
SiteName: cfg.Cluster,
ForwardAgent: fwdAgentMode,
Labels: cfg.Labels,
Expand All @@ -1479,7 +1479,7 @@ func (i *TeleInstance) NewUnauthenticatedClient(cfg ClientConfig) (tc *client.Te
TLSRoutingEnabled: i.IsSinglePortSetup,
TLSRoutingConnUpgradeRequired: cfg.ALBAddr != "",
Tracer: tracing.NoopProvider().Tracer("test"),
EnableEscapeSequences: cfg.EnableEscapeSequences,
DisableEscapeSequences: !cfg.EnableEscapeSequences,
Stderr: cfg.Stderr,
Stdin: cfg.Stdin,
Stdout: cfg.Stdout,
Expand Down
25 changes: 2 additions & 23 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ import (
"bytes"
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -75,7 +74,6 @@ import (
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/keypaths"
"github.com/gravitational/teleport/api/utils/prompt"
"github.com/gravitational/teleport/integration/helpers"
"github.com/gravitational/teleport/lib"
Expand Down Expand Up @@ -2439,28 +2437,9 @@ func twoClustersTunnel(t *testing.T, suite *integrationTestSuite, now time.Time,
err = tc.UpdateTrustedCA(ctx, a.GetSiteAPI(a.Secrets.SiteName))
require.NoError(t, err)

// The known_hosts file should have two certificates, the way bytes.Split
// works that means the output will be 3 (2 certs + 1 empty).
buffer, err := os.ReadFile(keypaths.KnownHostsPath(tc.KeysDir))
trustedCerts, err := tc.ClientStore.GetTrustedCerts(tc.WebProxyHost())
require.NoError(t, err)
parts := bytes.Split(buffer, []byte("\n"))
require.Len(t, parts, 3)

roots := x509.NewCertPool()
werr := filepath.Walk(keypaths.CAsDir(tc.KeysDir, Host), func(path string, info fs.FileInfo, err error) error {
require.NoError(t, err)
if info.IsDir() {
return nil
}
buffer, err = os.ReadFile(path)
require.NoError(t, err)
ok := roots.AppendCertsFromPEM(buffer)
require.True(t, ok)
return nil
})
require.NoError(t, werr)
ok := roots.AppendCertsFromPEM(buffer)
require.True(t, ok)
require.Len(t, trustedCerts, 2)

// wait for active tunnel connections to be established
helpers.WaitForActiveTunnelConnections(t, b.Tunnel, a.Secrets.SiteName, 1)
Expand Down
2 changes: 1 addition & 1 deletion integration/kube_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2111,7 +2111,7 @@ func kubeJoin(ctx context.Context, kubeConfig kube.ProxyConfig, tc *client.Telep
KubeProxyAddr: tc.Config.KubeProxyAddr,
WebProxyAddr: tc.Config.WebProxyAddr,
TLSRoutingConnUpgradeRequired: tc.Config.TLSRoutingConnUpgradeRequired,
EnableEscapeSequences: tc.Config.EnableEscapeSequences,
EnableEscapeSequences: !tc.Config.DisableEscapeSequences,
Tracker: meta,
TLSConfig: tlsConfig,
Mode: mode,
Expand Down
5 changes: 3 additions & 2 deletions integration/proxy/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/gravitational/teleport/lib/auth/mocku2f"
wancli "github.com/gravitational/teleport/lib/auth/webauthncli"
wantypes "github.com/gravitational/teleport/lib/auth/webauthntypes"
"github.com/gravitational/teleport/lib/client"
libclient "github.com/gravitational/teleport/lib/client"
"github.com/gravitational/teleport/lib/client/clientcache"
"github.com/gravitational/teleport/lib/client/mfa"
Expand Down Expand Up @@ -241,12 +242,12 @@ func testGatewayCertRenewal(ctx context.Context, t *testing.T, params gatewayCer

fakeClock := clockwork.NewFakeClockAt(time.Now())
storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
// Inject a fake clock into clusters.Storage so we can control when the middleware thinks the
// db cert has expired.
Clock: fakeClock,
WebauthnLogin: webauthnLogin,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -876,8 +877,8 @@ func testTeletermAppGatewayTargetPortValidation(t *testing.T, pack *appaccess.Pa
require.NoError(t, err)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)
daemonService, err := daemon.New(daemon.Config{
Expand Down
20 changes: 10 additions & 10 deletions integration/teleterm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ func testAddingRootCluster(t *testing.T, pack *dbhelpers.DatabasePack, creds *he
t.Helper()

storage, err := clusters.NewStorage(clusters.Config{
Dir: t.TempDir(),
InsecureSkipVerify: true,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -287,8 +287,8 @@ func testListRootClustersReturnsLoggedInUser(t *testing.T, pack *dbhelpers.Datab
tc := mustLogin(t, pack.Root.User.GetName(), pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -369,8 +369,8 @@ func testGetClusterReturnsPropertiesFromAuthServer(t *testing.T, pack *dbhelpers
tc := mustLogin(t, userName, pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -421,8 +421,8 @@ func testHeadlessWatcher(t *testing.T, pack *dbhelpers.DatabasePack, creds *help
tc := mustLogin(t, pack.Root.User.GetName(), pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -488,9 +488,9 @@ func testClientCache(t *testing.T, pack *dbhelpers.DatabasePack, creds *helpers.
storageFakeClock := clockwork.NewFakeClockAt(time.Now())

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
Clock: storageFakeClock,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -748,8 +748,8 @@ func testCreateConnectMyComputerRole(t *testing.T, pack *dbhelpers.DatabasePack)

// Prepare daemon.Service.
storage, err := clusters.NewStorage(clusters.Config{
Dir: t.TempDir(),
InsecureSkipVerify: true,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -862,10 +862,10 @@ func testCreateConnectMyComputerToken(t *testing.T, pack *dbhelpers.DatabasePack

// Prepare daemon.Service.
storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
Clock: fakeClock,
WebauthnLogin: webauthnLogin,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -924,8 +924,8 @@ func testWaitForConnectMyComputerNodeJoin(t *testing.T, pack *dbhelpers.Database
tc := mustLogin(t, pack.Root.User.GetName(), pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -1008,8 +1008,8 @@ func testDeleteConnectMyComputerNode(t *testing.T, pack *dbhelpers.DatabasePack)
tc := mustLogin(t, userName, pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down Expand Up @@ -1235,8 +1235,8 @@ func testListDatabaseUsers(t *testing.T, pack *dbhelpers.DatabasePack) {
tc := mustLogin(t, rootUserName, pack, creds)

storage, err := clusters.NewStorage(clusters.Config{
Dir: tc.KeysDir,
InsecureSkipVerify: tc.InsecureSkipVerify,
ClientStore: client.NewFSClientStore(t.TempDir()),
})
require.NoError(t, err)

Expand Down
8 changes: 4 additions & 4 deletions lib/benchmark/benchmark.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,9 @@ func work(ctx context.Context, m benchMeasure, send chan<- benchMeasure, workloa
// makeTeleportClient creates an instance of a teleport client
func makeTeleportClient(host, login, proxy string) (*client.TeleportClient, error) {
c := client.Config{
Host: host,
Tracer: tracing.NoopProvider().Tracer("test"),
Host: host,
Tracer: tracing.NoopProvider().Tracer("test"),
ClientStore: client.NewFSClientStore(""),
}

if login != "" {
Expand All @@ -295,8 +296,7 @@ func makeTeleportClient(host, login, proxy string) (*client.TeleportClient, erro
c.SSHProxyAddr = proxy
}

profileStore := client.NewFSProfileStore("")
if err := c.LoadProfile(profileStore, proxy); err != nil {
if err := c.LoadProfile(proxy); err != nil {
return nil, trace.Wrap(err)
}
tc, err := client.NewClient(&c)
Expand Down
Loading
Loading