diff --git a/api/utils/keys/piv/service.go b/api/utils/keys/piv/service.go index efbdb174cc25e..798ca5385911c 100644 --- a/api/utils/keys/piv/service.go +++ b/api/utils/keys/piv/service.go @@ -35,23 +35,25 @@ import ( "github.com/gravitational/teleport/api/utils/keys/hardwarekey" ) -// TODO(Joerger): Rather than using a global cache and mutexes, clients should be updated -// to create a single YubiKeyService and ensure it is reused across the program execution. -var ( +// yubiKeyService is a global YubiKeyService used to share yubikey connections +// and prompt mutex logic across the process in cases where [NewYubiKeyService] +// is called multiple times. +// +// TODO(Joerger): Ensure all clients initialize [NewYubiKeyService] only once so we can +// remove this global variable. +var yubiKeyService *YubiKeyService +var yubiKeyServiceMux sync.Mutex + +// YubiKeyService is a YubiKey PIV implementation of [hardwarekey.Service]. +type YubiKeyService struct { + prompt hardwarekey.Prompt + promptMux sync.Mutex + // yubiKeys is a shared, thread-safe [YubiKey] cache by serial number. It allows for // separate goroutines to share a YubiKey connection to work around the single PC/SC // transaction (connection) per-yubikey limit. - yubiKeys map[uint32]*YubiKey = map[uint32]*YubiKey{} + yubiKeys map[uint32]*YubiKey yubiKeysMux sync.Mutex - - // promptMux is used to prevent over-prompting, especially for back-to-back sign requests - // since touch/PIN from the first signature should be cached for following signatures. - promptMux sync.Mutex -) - -// YubiKeyService is a YubiKey PIV implementation of [hardwarekey.Service]. -type YubiKeyService struct { - prompt hardwarekey.Prompt } // Returns a new [YubiKeyService]. If [customPrompt] is nil, the default CLI prompt will be used. @@ -59,13 +61,26 @@ type YubiKeyService struct { // Only a single service should be created for each process to ensure the cached connections // are shared and multiple services don't compete for PIV resources. func NewYubiKeyService(customPrompt hardwarekey.Prompt) *YubiKeyService { + yubiKeyServiceMux.Lock() + defer yubiKeyServiceMux.Unlock() + + if yubiKeyService != nil { + // If a prompt is provided, prioritize it over the existing prompt value. + if customPrompt != nil { + yubiKeyService.prompt = customPrompt + } + return yubiKeyService + } + if customPrompt == nil { customPrompt = hardwarekey.NewStdCLIPrompt() } - return &YubiKeyService{ - prompt: customPrompt, + yubiKeyService = &YubiKeyService{ + prompt: customPrompt, + yubiKeys: map[uint32]*YubiKey{}, } + return yubiKeyService } // NewPrivateKey creates a hardware private key that satisfies the provided [config], @@ -170,8 +185,8 @@ func (s *YubiKeyService) Sign(ctx context.Context, ref *hardwarekey.PrivateKeyRe return nil, trace.Wrap(err) } - promptMux.Lock() - defer promptMux.Unlock() + s.promptMux.Lock() + defer s.promptMux.Unlock() return y.sign(ctx, ref, keyInfo, s.prompt, rand, digest, opts) } @@ -227,10 +242,10 @@ func (s *YubiKeyService) GetFullKeyRef(serialNumber uint32, slotKey hardwarekey. // Get the given YubiKey with the serial number. If the provided serialNumber is "0", // return the first YubiKey found in the smart card list. func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) { - yubiKeysMux.Lock() - defer yubiKeysMux.Unlock() + s.yubiKeysMux.Lock() + defer s.yubiKeysMux.Unlock() - if y, ok := yubiKeys[serialNumber]; ok { + if y, ok := s.yubiKeys[serialNumber]; ok { return y, nil } @@ -239,7 +254,7 @@ func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) { return nil, trace.Wrap(err) } - yubiKeys[y.serialNumber] = y + s.yubiKeys[y.serialNumber] = y return y, nil } @@ -247,8 +262,8 @@ func (s *YubiKeyService) getYubiKey(serialNumber uint32) (*YubiKey, error) { // If the user provides the default PIN, they will be prompted to set a // non-default PIN and PUK before continuing. func (s *YubiKeyService) checkOrSetPIN(ctx context.Context, y *YubiKey, keyInfo hardwarekey.ContextualKeyInfo) error { - promptMux.Lock() - defer promptMux.Unlock() + s.promptMux.Lock() + defer s.promptMux.Unlock() pin, err := s.prompt.AskPIN(ctx, hardwarekey.PINOptional, keyInfo) if err != nil { @@ -270,8 +285,8 @@ func (s *YubiKeyService) checkOrSetPIN(ctx context.Context, y *YubiKey, keyInfo } func (s *YubiKeyService) promptOverwriteSlot(ctx context.Context, msg string, keyInfo hardwarekey.ContextualKeyInfo) error { - promptMux.Lock() - defer promptMux.Unlock() + s.promptMux.Lock() + defer s.promptMux.Unlock() promptQuestion := fmt.Sprintf("%v\nWould you like to overwrite this slot's private key and certificate?", msg) if confirmed, confirmErr := s.prompt.ConfirmSlotOverwrite(ctx, promptQuestion, keyInfo); confirmErr != nil { diff --git a/api/utils/keys/piv/service_test.go b/api/utils/keys/piv/service_test.go index ffe45505d297c..1bce1e1f64441 100644 --- a/api/utils/keys/piv/service_test.go +++ b/api/utils/keys/piv/service_test.go @@ -17,6 +17,7 @@ package piv_test import ( + "bytes" "context" "crypto/x509/pkix" "fmt" @@ -48,7 +49,9 @@ func TestGetYubiKeyPrivateKey_Interactive(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - s := piv.NewYubiKeyService(hardwarekey.NewStdCLIPrompt()) + promptReader := prompt.NewFakeReader() + prompt := hardwarekey.NewCLIPrompt(os.Stderr, promptReader) + s := piv.NewYubiKeyService(prompt) y, err := piv.FindYubiKey(0) require.NoError(t, err) @@ -64,6 +67,13 @@ func TestGetYubiKeyPrivateKey_Interactive(t *testing.T) { require.NoError(t, err) require.NoError(t, priv.WarmupHardwareKey(ctx)) + // Set pin and handle expected prompts. + setupPINPrompt := func(t *testing.T) { + const testPIN = "123123" + require.NoError(t, y.SetPIN(pivgo.DefaultPIN, testPIN)) + promptReader.AddString(testPIN).AddString(testPIN) + } + for _, policy := range []hardwarekey.PromptPolicy{ hardwarekey.PromptPolicyNone, hardwarekey.PromptPolicyTouch, @@ -73,8 +83,10 @@ func TestGetYubiKeyPrivateKey_Interactive(t *testing.T) { for _, customSlot := range []bool{true, false} { t.Run(fmt.Sprintf("policy:%+v", policy), func(t *testing.T) { t.Run(fmt.Sprintf("custom slot:%v", customSlot), func(t *testing.T) { - resetYubikey(t, y) - setupPINPrompt(t, y) + setupPINPrompt(t) + t.Cleanup(func() { + resetYubikey(t, y) + }) var slot hardwarekey.PIVSlotKeyString = "" if customSlot { @@ -89,7 +101,7 @@ func TestGetYubiKeyPrivateKey_Interactive(t *testing.T) { require.NoError(t, err) // test HardwareSigner methods - require.Equal(t, policy, priv.GetPrivateKeyPolicy()) + require.Equal(t, policy, priv.GetPrivateKeyPolicy().GetPromptPolicy()) require.NotNil(t, priv.GetAttestationStatement()) require.True(t, priv.IsHardware()) @@ -122,7 +134,10 @@ func TestOverwritePrompt(t *testing.T) { ctx := context.Background() - s := piv.NewYubiKeyService(hardwarekey.NewStdCLIPrompt()) + promptWriter := bytes.NewBuffer([]byte{}) + promptReader := prompt.NewFakeReader() + prompt := hardwarekey.NewCLIPrompt(promptWriter, promptReader) + s := piv.NewYubiKeyService(prompt) y, err := piv.FindYubiKey(0) require.NoError(t, err) @@ -135,14 +150,14 @@ func TestOverwritePrompt(t *testing.T) { testOverwritePrompt := func(t *testing.T) { // Fail to overwrite slot when user denies - prompt.SetStdin(prompt.NewFakeReader().AddString("n")) + promptReader.AddString("n") _, err := keys.NewHardwarePrivateKey(ctx, s, hardwarekey.PrivateKeyConfig{ Policy: hardwarekey.PromptPolicy{TouchRequired: true}, }) require.True(t, trace.IsCompareFailed(err), "Expected compare failed error but got %v", err) // Successfully overwrite slot when user accepts - prompt.SetStdin(prompt.NewFakeReader().AddString("y")) + promptReader.AddString("y") _, err = keys.NewHardwarePrivateKey(ctx, s, hardwarekey.PrivateKeyConfig{ Policy: hardwarekey.PromptPolicy{TouchRequired: true}, }) @@ -178,16 +193,3 @@ func resetYubikey(t *testing.T, y *piv.YubiKey) { t.Helper() require.NoError(t, y.Reset()) } - -func setupPINPrompt(t *testing.T, y *piv.YubiKey) { - t.Helper() - - // Set pin for tests. - const testPIN = "123123" - require.NoError(t, y.SetPIN(pivgo.DefaultPIN, testPIN)) - - // Handle PIN prompt. - oldStdin := prompt.Stdin() - t.Cleanup(func() { prompt.SetStdin(oldStdin) }) - prompt.SetStdin(prompt.NewFakeReader().AddString(testPIN).AddString(testPIN)) -}