diff --git a/api/utils/keys/piv/service.go b/api/utils/keys/piv/service.go index 8f2f09b81cfed..157570088e3fd 100644 --- a/api/utils/keys/piv/service.go +++ b/api/utils/keys/piv/service.go @@ -188,8 +188,31 @@ func (s *YubiKeyService) Sign(ctx context.Context, ref *hardwarekey.PrivateKeyRe return nil, trace.Wrap(err) } - s.signMu.Lock() - defer s.signMu.Unlock() + pivSlot, err := parsePIVSlot(ref.SlotKey) + if err != nil { + return nil, trace.Wrap(err) + } + + // Check that the public key in the slot matches our record. + publicKey, err := y.getPublicKey(pivSlot) + if err != nil { + return nil, trace.Wrap(err) + } + + if !publicKey.Equal(ref.PublicKey) { + return nil, trace.CompareFailed("public key mismatch on PIV slot 0x%x", pivSlot.Key) + } + + // If the sign request is for an unknown agent key, ensure that the requested PIV slot was + // configured with a self-signed Teleport metadata certificate. + if keyInfo.AgentKeyInfo.UnknownAgentKey { + switch err := y.checkCertificate(pivSlot); { + case trace.IsNotFound(err), errors.As(err, &nonTeleportCertError{}): + return nil, trace.Wrap(err, agentRequiresTeleportCertMessage) + case err != nil: + return nil, trace.Wrap(err) + } + } return y.sign(ctx, ref, keyInfo, s.getPrompt(), rand, digest, opts) } diff --git a/api/utils/keys/piv/service_test.go b/api/utils/keys/piv/service_test.go index 981707470a6d2..471a42842f7ab 100644 --- a/api/utils/keys/piv/service_test.go +++ b/api/utils/keys/piv/service_test.go @@ -22,11 +22,13 @@ import ( "crypto/x509/pkix" "fmt" "os" + "sync" "testing" "time" pivgo "github.com/go-piv/piv-go/piv" "github.com/gravitational/trace" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/utils/keys" @@ -291,6 +293,47 @@ func TestPINCaching(t *testing.T) { require.Error(t, err) } +func TestConcurrentSignature(t *testing.T) { + // This test will overwrite any PIV data on the yubiKey. + if os.Getenv("TELEPORT_TEST_YUBIKEY_PIV") == "" { + t.Skipf("Skipping TestGenerateYubiKeyPrivateKey because TELEPORT_TEST_YUBIKEY_PIV is not set") + } + + ctx := context.Background() + promptReader := prompt.NewFakeReader() + prompt := hardwarekey.NewCLIPrompt(os.Stderr, promptReader) + s := piv.NewYubiKeyService(prompt) + + y, err := piv.FindYubiKey(0) + require.NoError(t, err) + + resetYubikey(t, y) + t.Cleanup(func() { resetYubikey(t, y) }) + + // Set pin. + const testPIN = "123123" + require.NoError(t, y.SetPIN(pivgo.DefaultPIN, testPIN)) + + promptReader.AddString(testPIN) + priv, err := keys.NewHardwarePrivateKey(ctx, s, hardwarekey.PrivateKeyConfig{ + // Use PIN policy to slow down the signatures a bit so that they are concurrent. + Policy: hardwarekey.PromptPolicyPIN, + }) + require.NoError(t, err) + + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + err = priv.WarmupHardwareKey(ctx) + assert.NoError(t, err) + }() + } + + wg.Wait() +} + // resetYubikey connects to the first yubiKey and resets it to defaults. func resetYubikey(t *testing.T, y *piv.YubiKey) { t.Helper() diff --git a/api/utils/keys/piv/yubikey.go b/api/utils/keys/piv/yubikey.go index a1ba9570f6b56..bb1eccdc0e0fb 100644 --- a/api/utils/keys/piv/yubikey.go +++ b/api/utils/keys/piv/yubikey.go @@ -60,6 +60,11 @@ type YubiKey struct { version piv.Version // pinCache can be used to skip PIN prompts for keys that have PIN caching enabled. pinCache *pinCache + + // promptMu prevents prompting for PIN/touch repeatedly for concurrent signatures. + // TODO(Joerger): Rather than preventing concurrent signatures, we can make the + // PIN and touch prompts durable to concurrent signatures. + promptMu sync.Mutex } // FindYubiKey finds a YubiKey PIV card by serial number. If the provided @@ -147,174 +152,57 @@ const ( signTouchPromptDelay = time.Millisecond * 200 ) -func (y *YubiKey) sign(ctx context.Context, ref *hardwarekey.PrivateKeyRef, keyInfo hardwarekey.ContextualKeyInfo, prompt hardwarekey.Prompt, rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - pivSlot, err := parsePIVSlot(ref.SlotKey) - if err != nil { - return nil, trace.Wrap(err) - } - - // Check that the public key in the slot matches our record. - slotCert, err := y.conn.attest(pivSlot) - if err != nil { - return nil, trace.Wrap(err) - } - type cryptoPublicKeyI interface { - Equal(x crypto.PublicKey) bool - } - if slotPub, ok := slotCert.PublicKey.(cryptoPublicKeyI); !ok { - return nil, trace.BadParameter("expected crypto.PublicKey but got %T", slotCert.PublicKey) - } else if !slotPub.Equal(ref.PublicKey) { - return nil, trace.CompareFailed("public key mismatch on PIV slot 0x%x", pivSlot.Key) - } - - // If the sign request is for an unknown agent key, ensure that the requested PIV slot was - // configured with a self-signed Teleport metadata certificate. - if keyInfo.AgentKeyInfo.UnknownAgentKey { - switch err := y.checkCertificate(pivSlot); { - case trace.IsNotFound(err), errors.As(err, &nonTeleportCertError{}): - return nil, trace.Wrap(err, agentRequiresTeleportCertMessage) - case err != nil: - return nil, trace.Wrap(err) - } - } - - ctx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) +const ( + // For generic auth errors, such as when PIN is not provided, the smart card returns the error code 0x6982. + // The piv-go library wraps error codes like this with a user readable message: "security status not satisfied". + pivGenericAuthErrCodeString = "6982" +) - // Lock the connection for the entire duration of the sign - // process. Without this, the connection will be released, - // leading to a failure when providing PIN or touch input: - // "verify pin: transmitting request: the supplied handle was invalid". - release, err := y.conn.connect() - if err != nil { - return nil, trace.Wrap(err) +func (y *YubiKey) sign(ctx context.Context, ref *hardwarekey.PrivateKeyRef, keyInfo hardwarekey.ContextualKeyInfo, prompt hardwarekey.Prompt, rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + // When using [piv.PINPolicyOnce], PIN is only required when it isn't cached in the PCSC + // transaction internally. The piv-go prompt logic attempts to check this requirement + // before prompting, which is generally workable. However, the PIN prompt logic is not + // flexible enough for the retry and PIN caching mechanisms supported in Teleport. As a + // result, we must first try signature without PIN and only prompt for PIN when we get a + // "security status not satisfied" error ([pivGenericAuthErrCodeString]). + // + // TODO(Joerger): Once https://github.com/go-piv/piv-go/pull/174 is merged upstream, we can + // check if PIN is required and verify PIN before attempting the signature. This is a more + // reliable method of checking the PIN requirement than the somewhat general auth error + // returned by the failed signature. + // IMPORTANT: Maintain the signature retry flow for firmware version 5.3.1, which has a bug + // with checking the PIN requirement - https://github.com/gravitational/teleport/pull/36427. + auth := piv.KeyAuth{ + PINPolicy: piv.PINPolicyNever, } - defer release() - var touchPromptDelayTimer *time.Timer + var promptTouch promptTouch if ref.Policy.TouchRequired { - touchPromptDelayTimer = time.NewTimer(signTouchPromptDelay) - defer touchPromptDelayTimer.Stop() - - go func() { - select { - case <-touchPromptDelayTimer.C: - // Prompt for touch after a delay, in case the function succeeds without touch due to a cached touch. - err := prompt.Touch(ctx, keyInfo) - if err != nil { - // Cancel the entire function when an error occurs. - // This is typically used for aborting the prompt. - cancel(trace.Wrap(err)) - } - return - case <-ctx.Done(): - // touch cached, skip prompt. - return - } - }() - } - - promptPIN := func() (string, error) { - // touch prompt delay is disrupted by pin prompts. To prevent misfired - // touch prompts, pause the timer for the duration of the pin prompt. - if touchPromptDelayTimer != nil { - if touchPromptDelayTimer.Stop() { - defer touchPromptDelayTimer.Reset(signTouchPromptDelay) - } + promptTouch = func(ctx context.Context) error { + return y.promptTouch(ctx, prompt, keyInfo) } - - return y.promptPIN(ctx, prompt, hardwarekey.PINRequired, keyInfo, ref.PINCacheTTL) } - pinPolicy := piv.PINPolicyNever - if ref.Policy.PINRequired { - pinPolicy = piv.PINPolicyOnce - } - - auth := piv.KeyAuth{ - PINPrompt: promptPIN, - PINPolicy: pinPolicy, - } - - // YubiKeys with firmware version 5.3.1 have a bug where insVerify(0x20, 0x00, 0x80, nil) - // clears the PIN cache instead of performing a non-mutable check. This causes the signature - // with pin policy "once" to fail unless PIN is provided for each call. We can avoid this bug - // by skipping the insVerify check and instead manually retrying with a PIN prompt only when - // the signature fails. - manualRetryWithPIN := false - fw531 := piv.Version{Major: 5, Minor: 3, Patch: 1} - if auth.PINPolicy == piv.PINPolicyOnce && y.conn.conn.Version() == fw531 { - // Set the keys PIN policy to never to skip the insVerify check. If PIN was provided in - // a previous recent call, the signature will succeed as expected of the "once" policy. - auth.PINPolicy = piv.PINPolicyNever - manualRetryWithPIN = true - } - - privateKey, err := y.conn.privateKey(pivSlot, ref.PublicKey, auth) - if err != nil { - return nil, trace.Wrap(err) - } - - signer, ok := privateKey.(crypto.Signer) - if !ok { - return nil, trace.BadParameter("private key type %T does not implement crypto.Signer", privateKey) - } - - // For generic auth errors, such as when PIN is not provided, the smart card returns the error code 0x6982. - // The piv-go library wraps error codes like this with a user readable message: "security status not satisfied". - const pivGenericAuthErrCodeString = "6982" - - signature, err := abandonableSign(ctx, signer, rand, digest, opts) + signature, err := y.conn.sign(ctx, ref, auth, promptTouch, rand, digest, opts) switch { case err == nil: return signature, nil - case manualRetryWithPIN && strings.Contains(err.Error(), pivGenericAuthErrCodeString): - pin, err := promptPIN() + case strings.Contains(err.Error(), pivGenericAuthErrCodeString) && ref.Policy.PINRequired: + pin, err := y.promptPIN(ctx, prompt, hardwarekey.PINRequired, keyInfo, ref.PINCacheTTL) if err != nil { return nil, trace.Wrap(err) } - if err := y.conn.verifyPIN(pin); err != nil { - return nil, trace.Wrap(err) - } - signature, err := abandonableSign(ctx, signer, rand, digest, opts) - return signature, trace.Wrap(err) + + // Setting the [piv.PINPolicyAlways] ensures that the PIN is used and skips + // the required check usually used with [piv.PINPolicyOnce]. + auth.PINPolicy = piv.PINPolicyAlways + auth.PIN = pin + return y.conn.sign(ctx, ref, auth, promptTouch, rand, digest, opts) default: return nil, trace.Wrap(err) } } -// abandonableSign is a wrapper around signer.Sign. -// It enhances the functionality of signer.Sign by allowing the caller to stop -// waiting for the result if the provided context is canceled. -// It is especially important for WarmupHardwareKey, -// where waiting for the user providing a PIN/touch could block program termination. -// Important: this function only abandons the signer.Sign result, doesn't cancel it. -func abandonableSign(ctx context.Context, signer crypto.Signer, rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { - type signResult struct { - signature []byte - err error - } - - signResultCh := make(chan signResult) - go func() { - if err := ctx.Err(); err != nil { - return - } - signature, err := signer.Sign(rand, digest, opts) - select { - case <-ctx.Done(): - case signResultCh <- signResult{signature: signature, err: trace.Wrap(err)}: - } - }() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case result := <-signResultCh: - return result.signature, trace.Wrap(result.err) - } -} - // Reset resets the YubiKey PIV module to default settings. func (y *YubiKey) Reset() error { err := y.conn.reset() @@ -407,6 +295,25 @@ func (y *YubiKey) checkCertificate(slot piv.Slot) error { return nil } +type cryptoPublicKey interface { + Equal(x crypto.PublicKey) bool +} + +// getPublicKey gets a public key from the given PIV slot. +func (y *YubiKey) getPublicKey(slot piv.Slot) (cryptoPublicKey, error) { + slotCert, err := y.conn.attest(slot) + if err != nil { + return nil, trace.Wrap(err, "failed to get slot cert on PIV slot 0x%x", slot.Key) + } + + slotPub, ok := slotCert.PublicKey.(cryptoPublicKey) + if !ok { + return nil, trace.BadParameter("expected crypto.PublicKey but got %T", slotCert.PublicKey) + } + + return slotPub, nil +} + // attestKey attests the key in the given PIV slot. // The key's public key can be found in the returned slotCert. func (y *YubiKey) attestKey(slot piv.Slot) (slotCert *x509.Certificate, attCert *x509.Certificate, att *piv.Attestation, err error) { @@ -489,6 +396,7 @@ func (y *YubiKey) checkOrSetPIN(ctx context.Context, prompt hardwarekey.Prompt, // the pin cache mutex or the exclusive PC/SC transaction. const pinPromptTimeout = time.Minute +// promptPIN prompts for PIN. If PIN caching is enabled, it verifies and caches the PIN for future calls. func (y *YubiKey) promptPIN(ctx context.Context, prompt hardwarekey.Prompt, requirement hardwarekey.PINPromptRequirement, keyInfo hardwarekey.ContextualKeyInfo, pinCacheTTL time.Duration) (string, error) { y.pinCache.mu.Lock() defer y.pinCache.mu.Unlock() @@ -501,6 +409,9 @@ func (y *YubiKey) promptPIN(ctx context.Context, prompt hardwarekey.Prompt, requ ctx, cancel := context.WithTimeout(ctx, pinPromptTimeout) defer cancel() + y.promptMu.Lock() + defer y.promptMu.Unlock() + pin, err := prompt.AskPIN(ctx, requirement, keyInfo) if err != nil { return "", trace.Wrap(err) @@ -517,6 +428,13 @@ func (y *YubiKey) promptPIN(ctx context.Context, prompt hardwarekey.Prompt, requ return pin, nil } +func (y *YubiKey) promptTouch(ctx context.Context, prompt hardwarekey.Prompt, keyInfo hardwarekey.ContextualKeyInfo) error { + y.promptMu.Lock() + defer y.promptMu.Unlock() + + return prompt.Touch(ctx, keyInfo) +} + func (y *YubiKey) setPINAndPUKFromDefault(ctx context.Context, prompt hardwarekey.Prompt, keyInfo hardwarekey.ContextualKeyInfo, pinCacheTTL time.Duration) (string, error) { y.pinCache.mu.Lock() defer y.pinCache.mu.Unlock() @@ -562,6 +480,10 @@ type sharedPIVConnection struct { conn *piv.YubiKey mu sync.Mutex activeConnections int + + // exclusiveOperationMu is used to ensure that PIV operations that don't + // support concurrency are not run concurrently. + exclusiveOperationMu sync.RWMutex } // connect establishes a connection to a YubiKey PIV module and returns a release function. @@ -609,6 +531,11 @@ func (c *sharedPIVConnection) connect() (func(), error) { retryCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + isRetryError := func(err error) bool { + const retryError = "connecting to smart card: the smart card cannot be accessed because of other connections outstanding" + return strings.Contains(err.Error(), retryError) + } + err = linearRetry.For(retryCtx, func() error { c.conn, err = piv.Open(c.card) if err != nil && !isRetryError(err) { @@ -634,14 +561,89 @@ func (c *sharedPIVConnection) connect() (func(), error) { return release, nil } -func (c *sharedPIVConnection) privateKey(slot piv.Slot, public crypto.PublicKey, auth piv.KeyAuth) (crypto.PrivateKey, error) { +type promptTouch func(ctx context.Context) error + +func (c *sharedPIVConnection) sign(ctx context.Context, ref *hardwarekey.PrivateKeyRef, auth piv.KeyAuth, promptTouch promptTouch, rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + pivSlot, err := parsePIVSlot(ref.SlotKey) + if err != nil { + return nil, trace.Wrap(err) + } + release, err := c.connect() if err != nil { return nil, trace.Wrap(err) } defer release() - privateKey, err := c.conn.PrivateKey(slot, public, auth) - return privateKey, trace.Wrap(err) + + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + + // Prepare the key and perform the signature with the same connection. + // Closing the connection in between breaks the underlying PIV handle. + priv, err := c.conn.PrivateKey(pivSlot, ref.PublicKey, auth) + if err != nil { + return nil, trace.Wrap(err) + } + + signer, ok := priv.(crypto.Signer) + if !ok { + return nil, trace.BadParameter("private key type %T does not implement crypto.Signer", priv) + } + + return abandonableSign(ctx, signer, promptTouch, rand, digest, opts) +} + +// abandonableSign extends [sharedPIVConnection.sign] to handle context, allowing the +// caller to stop waiting for the result if the provided context is canceled. +// +// This is necessary for hardware key signatures which sometimes require touch from the +// user to complete, which can block program termination. +// +// Important: this function only abandons the signer.Sign result, doesn't cancel it. +func abandonableSign(ctx context.Context, signer crypto.Signer, promptTouch promptTouch, rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Since this function isn't fully synchronous, the goroutines below may outlive + // the function call, especially sign which cannot be stopped once started. We + // use buffered channels to ensure these goroutines can send even with no receiver + // to avoid leaking. + signatureC := make(chan []byte, 1) + errC := make(chan error, 2) + + go func() { + signature, err := signer.Sign(rand, digest, opts) + if err != nil { + errC <- err + return + } + signatureC <- signature + }() + + if promptTouch != nil { + go func() { + // There is no built in mechanism to prompt for touch on demand, so we simply prompt for touch after + // a short duration in hopes of lining up with the actual YubiKey touch prompt (flashing key). In the + // case where touch is cached, the delay prevents the prompt from firing when it isn't needed. + select { + case <-time.After(signTouchPromptDelay): + if err := promptTouch(ctx); err != nil { + errC <- promptTouch(ctx) + } + case <-ctx.Done(): + // prompt cached or signature canceled, skip prompt. + } + }() + } + + select { + case <-ctx.Done(): + return nil, trace.Wrap(ctx.Err()) + case err := <-errC: + return nil, trace.Wrap(err) + case signature := <-signatureC: + return signature, nil + } } func (c *sharedPIVConnection) getSerialNumber() (uint32, error) { @@ -650,6 +652,10 @@ func (c *sharedPIVConnection) getSerialNumber() (uint32, error) { return 0, trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + serial, err := c.conn.Serial() return serial, trace.Wrap(err) } @@ -660,6 +666,8 @@ func (c *sharedPIVConnection) getVersion() (piv.Version, error) { return piv.Version{}, trace.Wrap(err) } defer release() + + // Version only requires an open connection, so we don't need to lock on [c.exclusiveOperationMu]. return c.conn.Version(), nil } @@ -669,6 +677,10 @@ func (c *sharedPIVConnection) reset() error { return trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + return trace.Wrap(c.conn.Reset()) } @@ -678,6 +690,10 @@ func (c *sharedPIVConnection) setCertificate(key [24]byte, slot piv.Slot, cert * return trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + return trace.Wrap(c.conn.SetCertificate(key, slot, cert)) } @@ -687,6 +703,10 @@ func (c *sharedPIVConnection) certificate(slot piv.Slot) (*x509.Certificate, err return nil, trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + cert, err := c.conn.Certificate(slot) return cert, trace.Wrap(err) } @@ -697,6 +717,10 @@ func (c *sharedPIVConnection) generateKey(key [24]byte, slot piv.Slot, opts piv. return nil, trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + pubKey, err := c.conn.GenerateKey(key, slot, opts) return pubKey, trace.Wrap(err) } @@ -707,6 +731,10 @@ func (c *sharedPIVConnection) attest(slot piv.Slot) (*x509.Certificate, error) { return nil, trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + cert, err := c.conn.Attest(slot) return cert, trace.Wrap(err) } @@ -717,6 +745,10 @@ func (c *sharedPIVConnection) attestationCertificate() (*x509.Certificate, error return nil, trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.Lock() + defer c.exclusiveOperationMu.Unlock() + cert, err := c.conn.AttestationCertificate() return cert, trace.Wrap(err) } @@ -727,6 +759,10 @@ func (c *sharedPIVConnection) setPIN(oldPIN string, newPIN string) error { return trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + return trace.Wrap(c.conn.SetPIN(oldPIN, newPIN)) } @@ -736,6 +772,10 @@ func (c *sharedPIVConnection) setPUK(oldPUK string, newPUK string) error { return trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + return trace.Wrap(c.conn.SetPUK(oldPUK, newPUK)) } @@ -745,6 +785,10 @@ func (c *sharedPIVConnection) unblock(puk string, newPIN string) error { return trace.Wrap(err) } defer release() + + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + return trace.Wrap(c.conn.Unblock(puk, newPIN)) } @@ -754,12 +798,11 @@ func (c *sharedPIVConnection) verifyPIN(pin string) error { return trace.Wrap(err) } defer release() - return trace.Wrap(c.conn.VerifyPIN(pin)) -} -func isRetryError(err error) bool { - const retryError = "connecting to smart card: the smart card cannot be accessed because of other connections outstanding" - return strings.Contains(err.Error(), retryError) + c.exclusiveOperationMu.RLock() + defer c.exclusiveOperationMu.RUnlock() + + return trace.Wrap(c.conn.VerifyPIN(pin)) } func parsePIVSlot(slotKey hardwarekey.PIVSlotKey) (piv.Slot, error) { diff --git a/api/utils/keys/piv/yubikey_test.go b/api/utils/keys/piv/yubikey_test.go new file mode 100644 index 0000000000000..0715910564efa --- /dev/null +++ b/api/utils/keys/piv/yubikey_test.go @@ -0,0 +1,139 @@ +//go:build piv + +// Copyright 2025 Gravitational, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package piv + +import ( + "context" + "crypto" + "crypto/x509/pkix" + "os" + "sync" + "testing" + + "github.com/go-piv/piv-go/piv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/utils/keys/hardwarekey" +) + +func TestConcurrentOperations(t *testing.T) { + // This test will overwrite any PIV data on the yubiKey. + if os.Getenv("TELEPORT_TEST_YUBIKEY_PIV") == "" { + t.Skipf("Skipping TestGenerateYubiKeyPrivateKey because TELEPORT_TEST_YUBIKEY_PIV is not set") + } + + y, err := FindYubiKey(0) + require.NoError(t, err) + + y.Reset() + t.Cleanup(func() { y.Reset() }) + + usedSlot := piv.SlotAuthentication + ref, err := y.generatePrivateKey(usedSlot, hardwarekey.PromptPolicyNone, hardwarekey.SignatureAlgorithmEC256, 0) + require.NoError(t, err) + require.NotNil(t, ref) + + unusedSlot := piv.SlotCardAuthentication + cert, err := SelfSignedMetadataCertificate(pkix.Name{}) + require.NoError(t, err) + + // Run each PIV command several times concurrently to ensure the concurrency + // protections in place properly protect each operations, especially those + // which do not support concurrency. + var wg sync.WaitGroup + for range 5 { + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.getSerialNumber() + assert.NoError(t, err, "getSerialNumber") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.sign(context.Background(), ref, piv.KeyAuth{PINPolicy: piv.PINPolicyNever}, nil, nil, make([]byte, 100), crypto.Hash(0)) + assert.NoError(t, err, "sign") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.getVersion() + assert.NoError(t, err, "getVersion") + }() + wg.Add(1) + go func() { + defer wg.Done() + err := y.conn.setCertificate(piv.DefaultManagementKey, unusedSlot, cert) + assert.NoError(t, err, "setCertificate") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.certificate(usedSlot) + assert.NoError(t, err, "certificate") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.generateKey(piv.DefaultManagementKey, unusedSlot, piv.Key{ + Algorithm: piv.AlgorithmEC256, + TouchPolicy: piv.TouchPolicyNever, + PINPolicy: piv.PINPolicyNever, + }) + assert.NoError(t, err, "generateKey") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.attest(usedSlot) + assert.NoError(t, err, "attest") + }() + wg.Add(1) + go func() { + defer wg.Done() + _, err := y.conn.attestationCertificate() + assert.NoError(t, err, "attestationCertificate") + }() + wg.Add(1) + go func() { + defer wg.Done() + err := y.conn.setPIN(piv.DefaultPIN, piv.DefaultPIN) + assert.NoError(t, err, "setPIN") + }() + wg.Add(1) + go func() { + defer wg.Done() + err := y.conn.setPUK(piv.DefaultPUK, piv.DefaultPUK) + assert.NoError(t, err, "setPUK") + }() + wg.Add(1) + go func() { + defer wg.Done() + err := y.conn.unblock(piv.DefaultPUK, piv.DefaultPIN) + assert.NoError(t, err, "unblock") + }() + wg.Add(1) + go func() { + defer wg.Done() + err := y.conn.verifyPIN(piv.DefaultPIN) + assert.NoError(t, err, "verifyPIN") + }() + } + + wg.Wait() +}