diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 17f0671926f55..e6c5849492004 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -5265,15 +5265,20 @@ func NewGRPCServer(cfg GRPCServerConfig) (*GRPCServer, error) { workloadidentityv1pb.RegisterWorkloadIdentityIssuanceServiceServer(server, workloadIdentityIssuanceService) workloadIdentityRevocationService, err := workloadidentityv1.NewRevocationService(&workloadidentityv1.RevocationServiceConfig{ - Authorizer: cfg.Authorizer, - Emitter: cfg.Emitter, - Clock: cfg.AuthServer.GetClock(), - Store: cfg.AuthServer.Services.WorkloadIdentityX509Revocations, + Authorizer: cfg.Authorizer, + Emitter: cfg.Emitter, + Clock: cfg.AuthServer.GetClock(), + Store: cfg.AuthServer.Services.WorkloadIdentityX509Revocations, + KeyStore: cfg.AuthServer.keyStore, + CertAuthorityGetter: cfg.AuthServer.Cache, + EventsWatcher: cfg.AuthServer.Services, + ClusterName: clusterName.GetClusterName(), }) if err != nil { - return nil, trace.Wrap(err, "creating workload identity issuance service") + return nil, trace.Wrap(err, "creating workload identity revocation service") } workloadidentityv1pb.RegisterWorkloadIdentityRevocationServiceServer(server, workloadIdentityRevocationService) + go workloadIdentityRevocationService.RunCRLSigner(cfg.AuthServer.CloseContext()) dbObjectImportRuleService, err := dbobjectimportrulev1.NewDatabaseObjectImportRuleService(dbobjectimportrulev1.DatabaseObjectImportRuleServiceConfig{ Authorizer: cfg.Authorizer, diff --git a/lib/auth/machineid/workloadidentityv1/revocation_service.go b/lib/auth/machineid/workloadidentityv1/revocation_service.go index 656ad53efba2a..83f80c2cae72f 100644 --- a/lib/auth/machineid/workloadidentityv1/revocation_service.go +++ b/lib/auth/machineid/workloadidentityv1/revocation_service.go @@ -18,7 +18,12 @@ package workloadidentityv1 import ( "context" + "crypto/rand" + "crypto/x509" "log/slog" + "math/big" + "sync" + "time" "github.com/gravitational/trace" 
"github.com/jonboulle/clockwork" @@ -26,10 +31,13 @@ import ( "github.com/gravitational/teleport" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" + "github.com/gravitational/teleport/api/observability/tracing" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/authz" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/tlsca" ) type workloadIdentityX509RevocationReadWriter interface { @@ -41,6 +49,14 @@ type workloadIdentityX509RevocationReadWriter interface { UpsertWorkloadIdentityX509Revocation(ctx context.Context, resource *workloadidentityv1pb.WorkloadIdentityX509Revocation) (*workloadidentityv1pb.WorkloadIdentityX509Revocation, error) } +type certAuthorityGetter interface { + GetCertAuthority(ctx context.Context, id types.CertAuthID, loadKeys bool) (types.CertAuthority, error) +} + +type eventsWatcher interface { + NewWatcher(ctx context.Context, watch types.Watch) (types.Watcher, error) +} + // RevocationServiceConfig holds configuration options for the RevocationService. type RevocationServiceConfig struct { Authorizer authz.Authorizer @@ -48,6 +64,15 @@ type RevocationServiceConfig struct { Clock clockwork.Clock Emitter apievents.Emitter Logger *slog.Logger + // CertAuthorityGetter is used to fetch the CA for signing the CRL. + CertAuthorityGetter certAuthorityGetter + EventsWatcher eventsWatcher + // ClusterName is the name of the cluster that the service is running in, + // used to fetch the correct CA for signing the CRL. + ClusterName string + // KeyStore is the key storer used to store and retrieve keys for the + // signing of the CRL. 
+ KeyStore KeyStorer } // RevocationService is the gRPC service for managing workload identity @@ -56,11 +81,26 @@ type RevocationServiceConfig struct { type RevocationService struct { workloadidentityv1pb.UnimplementedWorkloadIdentityRevocationServiceServer - authorizer authz.Authorizer - store workloadIdentityX509RevocationReadWriter - clock clockwork.Clock - emitter apievents.Emitter - logger *slog.Logger + authorizer authz.Authorizer + store workloadIdentityX509RevocationReadWriter + clock clockwork.Clock + emitter apievents.Emitter + logger *slog.Logger + certAuthorityGetter certAuthorityGetter + keyStore KeyStorer + clusterName string + eventsWatcher eventsWatcher + + crlSigningDebounce time.Duration + crlFailureBackoff time.Duration + crlPeriodicRenewal time.Duration + + // mu protects the signedCRL and notifyNewSignedCRL fields. + mu sync.Mutex + signedCRL []byte + // notifyNewSignedCRL will be closed when a new CRL is available. It is protected + // by mu. + notifyNewSignedCRL chan struct{} } // NewRevocationService returns a new instance of the RevocationService. 
@@ -72,6 +112,12 @@ func NewRevocationService(cfg *RevocationServiceConfig) (*RevocationService, err return nil, trace.BadParameter("authorizer is required") case cfg.Emitter == nil: return nil, trace.BadParameter("emitter is required") + case cfg.ClusterName == "": + return nil, trace.BadParameter("cluster name is required") + case cfg.KeyStore == nil: + return nil, trace.BadParameter("key storer is required") + case cfg.EventsWatcher == nil: + return nil, trace.BadParameter("events watcher is required") } if cfg.Logger == nil { @@ -81,11 +127,20 @@ func NewRevocationService(cfg *RevocationServiceConfig) (*RevocationService, err cfg.Clock = clockwork.NewRealClock() } return &RevocationService{ - authorizer: cfg.Authorizer, - store: cfg.Store, - clock: cfg.Clock, - emitter: cfg.Emitter, - logger: cfg.Logger, + authorizer: cfg.Authorizer, + store: cfg.Store, + clock: cfg.Clock, + emitter: cfg.Emitter, + logger: cfg.Logger, + clusterName: cfg.ClusterName, + certAuthorityGetter: cfg.CertAuthorityGetter, + keyStore: cfg.KeyStore, + eventsWatcher: cfg.EventsWatcher, + crlSigningDebounce: 5 * time.Second, + crlFailureBackoff: 30 * time.Second, + crlPeriodicRenewal: 10 * time.Minute, + + notifyNewSignedCRL: make(chan struct{}), }, nil } @@ -323,3 +378,298 @@ func (s *RevocationService) UpsertWorkloadIdentityX509Revocation( return created, nil } + +// StreamSignedCRL streams the signed CRL to the client. If the CRL has not +// yet been signed, the server will wait until it has been signed to send it +// to the client. +// Implements teleport.workloadidentity.v1.RevocationService/StreamSignedCRL +func (s *RevocationService) StreamSignedCRL( + req *workloadidentityv1pb.StreamSignedCRLRequest, + srv workloadidentityv1pb.WorkloadIdentityRevocationService_StreamSignedCRLServer, +) error { + for { + crl, notify := s.getSignedCRL() + + // The CRL may not yet have been signed, so, skip straight to waiting + // for an update. 
+ if len(crl) != 0 { + if err := srv.Send(&workloadidentityv1pb.StreamSignedCRLResponse{ + Crl: crl, + }); err != nil { + return trace.Wrap(err) + } + } + + select { + case <-notify: + case <-srv.Context().Done(): + return nil + } + } +} + +func (s *RevocationService) RunCRLSigner(ctx context.Context) { + for { + err := s.watchAndSign(ctx) + if err == nil { + if ctx.Err() != nil { + return + } + err = trace.BadParameter("watchAndSign exited unexpectedly") + } + retryAfter := retryutils.NewHalfJitter()(s.crlFailureBackoff) + if err != nil { + s.logger.ErrorContext( + ctx, + "CRL signer exited with error", + "error", err, + "retry_after", retryAfter, + ) + } + + select { + case <-ctx.Done(): + return + case <-s.clock.After(retryAfter): + s.logger.DebugContext(ctx, "Retry backoff expired, restarting CRL signer") + } + } + +} + +func (s *RevocationService) watchAndSign(ctx context.Context) error { + s.logger.DebugContext(ctx, "Starting CRL signer") + w, err := s.eventsWatcher.NewWatcher(ctx, types.Watch{ + Kinds: []types.WatchKind{{ + Kind: types.KindWorkloadIdentityX509Revocation, + }}, + }) + if err != nil { + return trace.Wrap(err, "creating events watcher") + } + defer func() { + if err := w.Close(); err != nil { + s.logger.WarnContext(ctx, "Failed to close watcher", "error", err) + } + }() + + // Wait for initial "Init" event to indicate we're now receiving events. 
+ select { + case <-w.Done(): + if err := w.Error(); err != nil { + return trace.Wrap(err, "watcher failed") + } + return nil + case evt := <-w.Events(): + if evt.Type == types.OpInit { + break + } + return trace.BadParameter("expected init event, got %v", evt.Type) + case <-ctx.Done(): + return nil + } + + revocationsSlice, err := s.fetchAllRevocations(ctx) + if err != nil { + return trace.Wrap(err, "initially fetching revocations") + } + revocationsMap := make(map[string]*workloadidentityv1pb.WorkloadIdentityX509Revocation, len(revocationsSlice)) + for _, revocation := range revocationsSlice { + revocationsMap[revocation.Metadata.Name] = revocation + } + + handleEvent := func(e types.Event) (bool, error) { + switch e.Type { + case types.OpPut: + unwrapper, ok := e.Resource.(types.Resource153Unwrapper) + if !ok { + return false, trace.BadParameter( + "expected event resource (%s) to implement Resource153Unwrapper", + e.Resource.GetName(), + ) + } + unwrapped := unwrapper.Unwrap() + revocation, ok := unwrapped.(*workloadidentityv1pb.WorkloadIdentityX509Revocation) + if !ok { + return false, trace.BadParameter( + "expected event resource (%s) to be a WorkloadIdentityX509Revocation, but it was %T", + e.Resource.GetName(), + unwrapped, + ) + } + revocationsMap[revocation.Metadata.Name] = revocation + return true, nil + case types.OpDelete: + delete(revocationsMap, e.Resource.GetName()) + return true, nil + default: + } + return false, nil + } + + // Perform initial signing of the CRL + crl, err := s.signCRL(ctx, revocationsMap) + if err != nil { + return trace.Wrap(err, "signing initial CRL") + } + s.publishSignedCRL(crl) + s.logger.DebugContext(ctx, "Finished initializing CRL signer, watching for revocation events") + + // A short, simple debounce so that we: + // - Avoid signing the CRL too frequently. This is computationally + // expensive and we can afford to wait a few seconds to group together + // multiple successive revocations. 
+ // - Avoid spamming the clients with a rapid succession of CRL updates. + var debounceCh <-chan time.Time + for { + periodic := s.clock.NewTimer(s.crlPeriodicRenewal) + select { + case e := <-w.Events(): + triggerSign, err := handleEvent(e) + if err != nil { + return trace.Wrap(err, "handling event") + } + if triggerSign { + s.logger.DebugContext(ctx, "Received change to WorkloadIdentityX509Revocation indicating new CRL should be signed", "workload_identity_revocation_name", e.Resource.GetName()) + if debounceCh == nil { + s.logger.DebugContext(ctx, "Starting debounce timer for signing of new CRL") + debounceCh = s.clock.After(s.crlSigningDebounce) + } + } + continue + case <-w.Done(): + if err := w.Error(); err != nil { + return trace.Wrap(err, "watcher failed") + } + return nil + case <-ctx.Done(): + return nil + case <-debounceCh: + // Set debounce channel to nil to indicate that the requested + // signature has been handled. + debounceCh = nil + + crl, err := s.signCRL(ctx, revocationsMap) + if err != nil { + return trace.Wrap(err, "signing CRL") + } + s.publishSignedCRL(crl) + case <-periodic.Chan(): + revocationsSlice, err := s.fetchAllRevocations(ctx) + if err != nil { + return trace.Wrap(err, "periodically fetching revocations") + } + newRevocationsMap := make(map[string]*workloadidentityv1pb.WorkloadIdentityX509Revocation, len(revocationsSlice)) + for _, revocation := range revocationsSlice { + newRevocationsMap[revocation.Metadata.Name] = revocation + } + revocationsMap = newRevocationsMap + crl, err := s.signCRL(ctx, revocationsMap) + if err != nil { + return trace.Wrap(err, "signing CRL") + } + s.publishSignedCRL(crl) + } + } +} + +func (s *RevocationService) fetchAllRevocations(ctx context.Context) ([]*workloadidentityv1pb.WorkloadIdentityX509Revocation, error) { + pageToken := "" + revocations := []*workloadidentityv1pb.WorkloadIdentityX509Revocation{} + for { + res, token, err := s.store.ListWorkloadIdentityX509Revocations( + ctx, 0, pageToken, + ) + 
if err != nil { + return nil, trace.Wrap(err) + } + revocations = append(revocations, res...) + if token == "" { + break + } + pageToken = token + } + return revocations, nil +} + +// signCRL signs a new revocation list for the given set of revocations, and +// returns this as an ASN.1 DER encoded CRL. +func (s *RevocationService) signCRL( + ctx context.Context, + revocations map[string]*workloadidentityv1pb.WorkloadIdentityX509Revocation, +) (_ []byte, err error) { + ctx, span := tracer.Start(ctx, "RevocationService/signCRL") + defer func() { + tracing.EndSpan(span, err) + }() + + s.logger.InfoContext(ctx, "Starting to generate new CRL") + ca, err := s.certAuthorityGetter.GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: s.clusterName, + }, true) + if err != nil { + return nil, trace.Wrap(err, "getting CA") + } + tlsCert, tlsSigner, err := s.keyStore.GetTLSCertAndSigner(ctx, ca) + if err != nil { + return nil, trace.Wrap(err, "getting CA cert and key") + } + tlsCA, err := tlsca.FromCertAndSigner(tlsCert, tlsSigner) + if err != nil { + return nil, trace.Wrap(err, "creating TLS CA") + } + + // RFC 5280 Certificate Revocation List + // Ref: https://datatracker.ietf.org/doc/html/rfc5280#section-5 + tmpl := &x509.RevocationList{ + // Ref: https://www.rfc-editor.org/rfc/rfc5280.html#section-5.1.2.6 + RevokedCertificateEntries: make([]x509.RevocationListEntry, 0, len(revocations)), + // Ref: https://www.rfc-editor.org/rfc/rfc5280.html#section-5.2.3 + // This is an optional extension we will be omitting for now, at a + // future date, we may insert a monotonically increasing identifier. 
+ Number: big.NewInt(s.clock.Now().Unix()), + } + + for _, revocation := range revocations { + serial := new(big.Int) + _, ok := serial.SetString(revocation.Metadata.Name, 16) + if !ok { + s.logger.WarnContext( + ctx, + "Encountered WorkloadIdentityX509Revocation with unparsable serial number, it will be omitted from the CRL", + "workload_identity_revocation_name", revocation.Metadata.Name, + ) + continue + } + + tmpl.RevokedCertificateEntries = append(tmpl.RevokedCertificateEntries, x509.RevocationListEntry{ + SerialNumber: serial, + RevocationTime: revocation.Spec.RevokedAt.AsTime(), + }) + } + + signedCRL, err := x509.CreateRevocationList( + rand.Reader, tmpl, tlsCA.Cert, tlsCA.Signer, + ) + if err != nil { + return nil, trace.Wrap(err) + } + s.logger.InfoContext(ctx, "Finished generating new CRL", "revocations", len(revocations)) + return signedCRL, nil +} + +func (s *RevocationService) getSignedCRL() ([]byte, chan struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + return s.signedCRL, s.notifyNewSignedCRL +} + +func (s *RevocationService) publishSignedCRL(crl []byte) { + s.mu.Lock() + defer s.mu.Unlock() + s.signedCRL = crl + // Close old channel to notify clients that a new CRL is available. 
+ close(s.notifyNewSignedCRL) + s.notifyNewSignedCRL = make(chan struct{}) +} diff --git a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go index 349c78e4fbfb3..6e9e71e7ad110 100644 --- a/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go +++ b/lib/auth/machineid/workloadidentityv1/workloadidentityv1_test.go @@ -25,6 +25,7 @@ import ( "crypto/x509/pkix" "errors" "fmt" + "math/big" "net" "os" "slices" @@ -2991,3 +2992,166 @@ func TestRevocationService_UpsertWorkloadIdentityX509Revocation(t *testing.T) { }) } } + +func TestRevocationService_CRL(t *testing.T) { + t.Parallel() + srv, _ := newTestTLSServer(t) + ctx := context.Background() + fakeClock := srv.Clock().(clockwork.FakeClock) + + authorizedUser, _, err := auth.CreateUserAndRole( + srv.Auth(), + "authorized", + []string{}, + []types.Rule{ + { + Resources: []string{types.KindWorkloadIdentityX509Revocation}, + Verbs: []string{ + types.VerbRead, + types.VerbList, + types.VerbCreate, + types.VerbUpdate, + types.VerbDelete, + }, + }, + }) + require.NoError(t, err) + authorizedClient, err := srv.NewClient(auth.TestUser(authorizedUser.GetName())) + require.NoError(t, err) + revocationsClient := authorizedClient.WorkloadIdentityRevocationServiceClient() + + // Fetch the SPIFFE CA so we can validate CRL signature. 
+ ca, err := srv.Auth().GetCertAuthority(ctx, types.CertAuthID{ + Type: types.SPIFFECA, + DomainName: srv.ClusterName(), + }, false) + require.NoError(t, err) + caCert, err := tlsca.ParseCertificatePEM(ca.GetActiveKeys().TLS[0].Cert) + require.NoError(t, err) + + checkCRL := func( + t *testing.T, + crlBytes []byte, + wantEntries []x509.RevocationListEntry, + ) { + require.NotEmpty(t, crlBytes) + + // Expect a DER encoded CRL directly (e.g no PEM) + parsed, err := x509.ParseRevocationList(crlBytes) + require.NoError(t, err) + + // Check CRL has a valid signature + require.NoError(t, parsed.CheckSignatureFrom(caCert)) + + diff := cmp.Diff( + wantEntries, + parsed.RevokedCertificateEntries, + cmp.Comparer(func(a, b *big.Int) bool { + return a.Cmp(b) == 0 + }), + cmpopts.IgnoreFields(x509.RevocationListEntry{}, "Raw"), + cmpopts.SortSlices(func(a, b x509.RevocationListEntry) bool { + return a.SerialNumber.Cmp(b.SerialNumber) < 0 + }), + ) + require.Empty(t, diff) + } + + revokedAt := srv.Clock().Now() + createRevocation := func(t *testing.T, name string) { + _, err = revocationsClient.CreateWorkloadIdentityX509Revocation( + ctx, + &workloadidentityv1pb.CreateWorkloadIdentityX509RevocationRequest{ + WorkloadIdentityX509Revocation: &workloadidentityv1pb.WorkloadIdentityX509Revocation{ + Kind: types.KindWorkloadIdentityX509Revocation, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: name, + Expires: timestamppb.New(srv.Clock().Now().Add(time.Hour)), + }, + Spec: &workloadidentityv1pb.WorkloadIdentityX509RevocationSpec{ + Reason: "compromised", + RevokedAt: timestamppb.New(revokedAt), + }, + }, + }, + ) + require.NoError(t, err) + } + deleteRevocation := func(t *testing.T, name string) { + _, err = revocationsClient.DeleteWorkloadIdentityX509Revocation( + ctx, + &workloadidentityv1pb.DeleteWorkloadIdentityX509RevocationRequest{ + Name: name, + }, + ) + require.NoError(t, err) + } + + // Fetch the initial, empty, CRL + stream, err := 
revocationsClient.StreamSignedCRL( + ctx, &workloadidentityv1pb.StreamSignedCRLRequest{}, + ) + require.NoError(t, err) + res, err := stream.Recv() + require.NoError(t, err) + checkCRL(t, res.Crl, nil) + + // Create new revocations + createRevocation(t, "ff") + createRevocation(t, "aa") + fakeClock.BlockUntil(2) + t.Log("Advancing fake clock to pass debounce period") + fakeClock.Advance(6 * time.Second) + // The client should now receive a new CRL + res, err = stream.Recv() + require.NoError(t, err) + checkCRL(t, res.Crl, []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(170), + RevocationTime: revokedAt, + }, + { + SerialNumber: big.NewInt(255), + RevocationTime: revokedAt, + }, + }) + + // Add another revocation, delete one revocation + createRevocation(t, "bb") + deleteRevocation(t, "aa") + fakeClock.BlockUntil(2) + t.Log("Advancing fake clock to pass debounce period") + fakeClock.Advance(6 * time.Second) + // The client should now receive a new CRL + res, err = stream.Recv() + require.NoError(t, err) + checkCRL(t, res.Crl, []x509.RevocationListEntry{ + { + SerialNumber: big.NewInt(255), + RevocationTime: revokedAt, + }, + { + SerialNumber: big.NewInt(187), + RevocationTime: revokedAt, + }, + }) + + // Delete all remaining CRL + deleteRevocation(t, "bb") + deleteRevocation(t, "ff") + fakeClock.BlockUntil(2) + t.Log("Advancing fake clock to pass debounce period") + fakeClock.Advance(6 * time.Second) + // The client should now receive a new CRL + res, err = stream.Recv() + require.NoError(t, err) + checkCRL(t, res.Crl, nil) + + // Wait ten minutes to see if the periodic CRL is sent. 
+ t.Log("Advancing fake clock to pass the periodic timer") + fakeClock.Advance(11 * time.Minute) + res, err = stream.Recv() + require.NoError(t, err) + checkCRL(t, res.Crl, nil) +} diff --git a/lib/tbot/config/service_spiffe_svid.go b/lib/tbot/config/service_spiffe_svid.go index c72928608804d..d360603160b14 100644 --- a/lib/tbot/config/service_spiffe_svid.go +++ b/lib/tbot/config/service_spiffe_svid.go @@ -38,6 +38,7 @@ const ( SVIDPEMPath = "svid.pem" SVIDKeyPEMPath = "svid_key.pem" SVIDTrustBundlePEMPath = "svid_bundle.pem" + SVIDCRLPemPath = "svid_crl.pem" ) // SVIDRequestSANs is the configuration for the SANs of a single SVID request. diff --git a/lib/tbot/config/service_workload_identity_x509.go b/lib/tbot/config/service_workload_identity_x509.go index adff17f991eee..5932531f3fdcc 100644 --- a/lib/tbot/config/service_workload_identity_x509.go +++ b/lib/tbot/config/service_workload_identity_x509.go @@ -106,6 +106,9 @@ func (o *WorkloadIdentityX509Service) Describe() []FileDescription { { Name: SVIDTrustBundlePEMPath, }, + { + Name: SVIDCRLPemPath, + }, } return fds } diff --git a/lib/tbot/service_workload_identity_api.go b/lib/tbot/service_workload_identity_api.go index e9159dfe9cf73..f7092c124cf88 100644 --- a/lib/tbot/service_workload_identity_api.go +++ b/lib/tbot/service_workload_identity_api.go @@ -70,6 +70,7 @@ type WorkloadIdentityAPIService struct { log *slog.Logger resolver reversetunnelclient.Resolver trustBundleCache *workloadidentity.TrustBundleCache + crlCache *workloadidentity.CRLCache // client holds the impersonated client for the service client *authclient.Client @@ -293,7 +294,11 @@ func (s *WorkloadIdentityAPIService) FetchX509SVID( bundleSet, err := s.trustBundleCache.GetBundleSet(ctx) if err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "fetching trust bundle set from cache") + } + crlSet, err := s.crlCache.GetCRLSet(ctx) + if err != nil { + return trace.Wrap(err, "fetching CRL set from cache") } var svids 
[]*workloadpb.X509SVID @@ -322,10 +327,16 @@ func (s *WorkloadIdentityAPIService) FetchX509SVID( } } - err = srv.Send(&workloadpb.X509SVIDResponse{ + + resp := &workloadpb.X509SVIDResponse{ Svids: svids, FederatedBundles: bundleSet.EncodedX509Bundles(false), - }) + } + if len(crlSet.LocalCRL) > 0 { + resp.Crl = [][]byte{crlSet.LocalCRL} + } + + err = srv.Send(resp) if err != nil { return trace.Wrap(err) } @@ -350,6 +361,14 @@ func (s *WorkloadIdentityAPIService) FetchX509SVID( } bundleSet = newBundleSet continue + case <-crlSet.Stale(): + newCRLSet, err := s.crlCache.GetCRLSet(ctx) + if err != nil { + return trace.Wrap(err) + } + log.DebugContext(ctx, "CRL set has been updated, distributing to client") + crlSet = newCRLSet + continue case <-time.After(s.botCfg.RenewalInterval): log.DebugContext(ctx, "Renewal interval reached, renewing SVIDs") svids = nil @@ -373,13 +392,21 @@ func (s *WorkloadIdentityAPIService) FetchX509Bundles( for { bundleSet, err := s.trustBundleCache.GetBundleSet(ctx) if err != nil { - return trace.Wrap(err) + return trace.Wrap(err, "fetching trust bundle set from cache") + } + crlSet, err := s.crlCache.GetCRLSet(ctx) + if err != nil { + return trace.Wrap(err, "fetching CRL set from cache") } s.log.InfoContext(ctx, "Sending X.509 trust bundles to workload") - err = srv.Send(&workloadpb.X509BundlesResponse{ + resp := &workloadpb.X509BundlesResponse{ Bundles: bundleSet.EncodedX509Bundles(true), - }) + } + if len(crlSet.LocalCRL) > 0 { + resp.Crl = [][]byte{crlSet.LocalCRL} + } + err = srv.Send(resp) if err != nil { return trace.Wrap(err) } @@ -388,6 +415,9 @@ func (s *WorkloadIdentityAPIService) FetchX509Bundles( case <-ctx.Done(): return nil case <-bundleSet.Stale(): + s.log.DebugContext(ctx, "Trust bundle set has been updated, distributing to client") + case <-crlSet.Stale(): + s.log.DebugContext(ctx, "CRL set has been updated, distributing to client") } } } diff --git a/lib/tbot/service_workload_identity_api_test.go 
b/lib/tbot/service_workload_identity_api_test.go index 921b70f398037..6208bf82d77f5 100644 --- a/lib/tbot/service_workload_identity_api_test.go +++ b/lib/tbot/service_workload_identity_api_test.go @@ -18,6 +18,7 @@ package tbot import ( "context" + "crypto/x509" "fmt" "net/url" "os" @@ -25,11 +26,14 @@ import ( "sync" "testing" + "github.com/spiffe/go-spiffe/v2/proto/spiffe/workload" "github.com/spiffe/go-spiffe/v2/svid/jwtsvid" "github.com/spiffe/go-spiffe/v2/svid/x509svid" "github.com/spiffe/go-spiffe/v2/workloadapi" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" headerv1 "github.com/gravitational/teleport/api/gen/proto/go/teleport/header/v1" workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" @@ -166,4 +170,27 @@ func TestBotWorkloadIdentityAPI(t *testing.T) { require.NoError(t, err) _, err = jwtsvid.ParseAndValidate(jwtSVID.Marshal(), jwtBundles, []string{"example.com"}) require.NoError(t, err) + + // Check CRL is delivered - we have to manually craft the client for this + // since the current go-spiffe SDK doesn't support this. + // TODO(noah): I'll raise some changes upstream to add CRL field support to + // the go-spiffe SDK, and then we can remove this code. 
+ conn, err := grpc.NewClient( + listenAddr.String(), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + require.NoError(t, err) + spiffeWorkloadAPI := workload.NewSpiffeWorkloadAPIClient(conn) + stream, err := spiffeWorkloadAPI.FetchX509SVID(ctx, &workload.X509SVIDRequest{}) + require.NoError(t, err) + + resp, err := stream.Recv() + require.NoError(t, err) + require.Len(t, resp.Crl, 1) + crl, err := x509.ParseRevocationList(resp.Crl[0]) + require.NoError(t, err) + require.Empty(t, crl.RevokedCertificateEntries) + tb, ok := set.Get(svid.ID.TrustDomain()) + require.True(t, ok) + require.NoError(t, crl.CheckSignatureFrom(tb.X509Authorities()[0])) } diff --git a/lib/tbot/service_workload_identity_x509.go b/lib/tbot/service_workload_identity_x509.go index d54969bb00418..522cf125e59c6 100644 --- a/lib/tbot/service_workload_identity_x509.go +++ b/lib/tbot/service_workload_identity_x509.go @@ -49,6 +49,7 @@ type WorkloadIdentityX509Service struct { // trustBundleCache is the cache of trust bundles. It only needs to be // provided when running in daemon mode. trustBundleCache *workloadidentity.TrustBundleCache + crlCache *workloadidentity.CRLCache } // String returns a human-readable description of the service. 
@@ -73,9 +74,16 @@ func (s *WorkloadIdentityX509Service) OneShot(ctx context.Context) error { ) if err != nil { return trace.Wrap(err, "fetching trust bundle set") - } - return s.render(ctx, bundleSet, res, privateKey) + crlSet, err := workloadidentity.FetchCRLSet( + ctx, + s.botAuthClient.WorkloadIdentityRevocationServiceClient(), + ) + if err != nil { + return trace.Wrap(err, "fetching CRL set") + } + + return s.render(ctx, bundleSet, res, privateKey, crlSet) } // Run runs the service in daemon mode, periodically generating the output and @@ -85,6 +93,10 @@ func (s *WorkloadIdentityX509Service) Run(ctx context.Context) error { if err != nil { return trace.Wrap(err, "getting trust bundle set") } + crlSet, err := s.crlCache.GetCRLSet(ctx) + if err != nil { + return trace.Wrap(err, "getting CRL set from cache") + } jitter := retryutils.NewJitter() var x509Cred *workloadidentityv1pb.Credential @@ -118,7 +130,7 @@ func (s *WorkloadIdentityX509Service) Run(ctx context.Context) error { if err != nil { return trace.Wrap(err, "getting trust bundle set") } - s.log.InfoContext(ctx, "Trust bundle set has been updated") + s.log.InfoContext(ctx, "Trust bundle set has been updated, will regenerate output") if !newBundleSet.Local.Equal(bundleSet.Local) { // If the local trust domain CA has changed, we need to reissue // the SVID. 
@@ -126,6 +138,13 @@ func (s *WorkloadIdentityX509Service) Run(ctx context.Context) error { privateKey = nil } bundleSet = newBundleSet + case <-crlSet.Stale(): + newCRLSet, err := s.crlCache.GetCRLSet(ctx) + if err != nil { + return trace.Wrap(err, "getting CRL set from cache") + } + crlSet = newCRLSet + s.log.DebugContext(ctx, "CRL set has been updated, will regenerate output") case <-time.After(s.botCfg.RenewalInterval): s.log.InfoContext(ctx, "Renewal interval reached, renewing SVIDs") x509Cred = nil @@ -142,7 +161,9 @@ func (s *WorkloadIdentityX509Service) Run(ctx context.Context) error { continue } } - if err := s.render(ctx, bundleSet, x509Cred, privateKey); err != nil { + if err := s.render( + ctx, bundleSet, x509Cred, privateKey, crlSet, + ); err != nil { s.log.ErrorContext(ctx, "Failed to render output", "error", err) failures++ continue @@ -232,6 +253,7 @@ func (s *WorkloadIdentityX509Service) render( bundleSet *workloadidentity.BundleSet, x509Cred *workloadidentityv1pb.Credential, privateKey crypto.Signer, + crlSet *workloadidentity.CRLSet, ) error { ctx, span := tracer.Start( ctx, @@ -295,6 +317,13 @@ func (s *WorkloadIdentityX509Service) render( return trace.Wrap(err, "writing svid trust bundle") } + crlBytes := crlSet.Marshal() + if len(crlBytes) > 0 { + if err := s.cfg.Destination.Write(ctx, config.SVIDCRLPemPath, crlBytes); err != nil { + return trace.Wrap(err, "writing CRL") + } + } + s.log.InfoContext( ctx, "Successfully wrote X509 workload identity credential to destination", diff --git a/lib/tbot/service_workload_identity_x509_test.go b/lib/tbot/service_workload_identity_x509_test.go index 00e4317b7a3eb..d6ac5f85e5ce5 100644 --- a/lib/tbot/service_workload_identity_x509_test.go +++ b/lib/tbot/service_workload_identity_x509_test.go @@ -18,6 +18,9 @@ package tbot import ( "context" + "crypto/x509" + "encoding/pem" + "os" "path" "path/filepath" "testing" @@ -84,6 +87,15 @@ func TestBotWorkloadIdentityX509(t *testing.T) { }) require.NoError(t, err) 
+ checkCRL := func(t *testing.T, tmpDir string, bundle *x509bundle.Bundle) { + crlPEM, err := os.ReadFile(filepath.Join(tmpDir, config.SVIDCRLPemPath)) + require.NoError(t, err) + crlBytes, _ := pem.Decode(crlPEM) + crl, err := x509.ParseRevocationList(crlBytes.Bytes) + require.NoError(t, err) + require.NoError(t, crl.CheckSignatureFrom(bundle.X509Authorities()[0])) + } + t.Run("By Name", func(t *testing.T) { tmpDir := t.TempDir() onboarding, _ := makeBot(t, rootClient, "by-name", role.GetName()) @@ -123,6 +135,8 @@ func TestBotWorkloadIdentityX509(t *testing.T) { require.NoError(t, err) _, _, err = x509svid.Verify(svid.Certificates, bundle) require.NoError(t, err) + + checkCRL(t, tmpDir, bundle) }) t.Run("By Labels", func(t *testing.T) { tmpDir := t.TempDir() @@ -165,5 +179,7 @@ func TestBotWorkloadIdentityX509(t *testing.T) { require.NoError(t, err) _, _, err = x509svid.Verify(svid.Certificates, bundle) require.NoError(t, err) + + checkCRL(t, tmpDir, bundle) }) } diff --git a/lib/tbot/tbot.go b/lib/tbot/tbot.go index 5cae9393a83e3..2e5271dab82a0 100644 --- a/lib/tbot/tbot.go +++ b/lib/tbot/tbot.go @@ -305,6 +305,25 @@ func (b *Bot) Run(ctx context.Context) (err error) { services = append(services, trustBundleCache) return trustBundleCache, nil } + var crlCache *workloadidentity.CRLCache + setupCRLCache := func() (*workloadidentity.CRLCache, error) { + if crlCache != nil { + return crlCache, nil + } + + var err error + crlCache, err = workloadidentity.NewCRLCache(workloadidentity.CRLCacheConfig{ + RevocationsClient: b.botIdentitySvc.GetClient().WorkloadIdentityRevocationServiceClient(), + Logger: b.log.With( + teleport.ComponentKey, teleport.Component(componentTBot, "crl-cache"), + ), + }) + if err != nil { + return nil, trace.Wrap(err) + } + services = append(services, crlCache) + return crlCache, nil + } // Append any services configured by the user for _, svcCfg := range b.cfg.Services { @@ -506,6 +525,11 @@ func (b *Bot) Run(ctx context.Context) (err error) { 
return trace.Wrap(err) } svc.trustBundleCache = tbCache + crlCache, err := setupCRLCache() + if err != nil { + return trace.Wrap(err) + } + svc.crlCache = crlCache } services = append(services, svc) case *config.WorkloadIdentityJWTService: @@ -547,6 +571,10 @@ func (b *Bot) Run(ctx context.Context) error { if err != nil { return trace.Wrap(err) } + crlCache, err := setupCRLCache() + if err != nil { + return trace.Wrap(err) + } svc := &WorkloadIdentityAPIService{ svcIdentity: clientCredential, @@ -554,6 +582,7 @@ cfg: svcCfg, resolver: resolver, trustBundleCache: tbCache, + crlCache: crlCache, } svc.log = b.log.With( teleport.ComponentKey, teleport.Component(componentTBot, "svc", svc.String()), diff --git a/lib/tbot/workloadidentity/crl_cache.go b/lib/tbot/workloadidentity/crl_cache.go new file mode 100644 index 0000000000000..22bb6c0300253 --- --- /dev/null +++ b/lib/tbot/workloadidentity/crl_cache.go @@ -0,0 +1,247 @@ +/* + * Teleport + * Copyright (C) 2025 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <https://www.gnu.org/licenses/>. + */ + +package workloadidentity + +import ( + "bytes" + "context" + "encoding/pem" + "log/slog" + "sync" + "time" + + "github.com/gravitational/trace" + + workloadidentityv1pb "github.com/gravitational/teleport/api/gen/proto/go/teleport/workloadidentity/v1" +) + +// CRLSet is a collection of CRLs. 
+type CRLSet struct { + // LocalCRL is the CRL related to the local trust domain. + LocalCRL []byte + // stale is closed to indicate that this CRLSet has been replaced. + stale chan struct{} +} + +// Clone returns a deep copy of the CRLSet. +func (b *CRLSet) Clone() *CRLSet { + clone := &CRLSet{ + stale: b.stale, + } + if b.LocalCRL != nil { + clone.LocalCRL = make([]byte, len(b.LocalCRL)) + copy(clone.LocalCRL, b.LocalCRL) + } + return clone +} + +// Marshal returns the CRL Set encoded in PEM format. NOTE(review): this emits +// a PEM block even when LocalCRL is empty/nil, so callers gating on len() > 0 would still write an empty CRL — confirm that is intended or return nil here. +func (b *CRLSet) Marshal() []byte { + return pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: b.LocalCRL, + }) +} + +// Stale returns a channel that will be closed when the CRLSet is stale +// and a new CRLSet is available. +func (b *CRLSet) Stale() <-chan struct{} { + return b.stale +} + +// CRLCache streams CRLs from the revocations service and caches them. It +// provides a mechanism to inform consumers when a new CRL is available. +type CRLCache struct { + revocationsClient workloadidentityv1pb.WorkloadIdentityRevocationServiceClient + logger *slog.Logger + + mu sync.Mutex + crlSet *CRLSet + // initialized is closed once the cache is fully initialized. + initialized chan struct{} +} + +// CRLCacheConfig is the configuration for a CRLCache. +type CRLCacheConfig struct { + RevocationsClient workloadidentityv1pb.WorkloadIdentityRevocationServiceClient + Logger *slog.Logger +} + +// NewCRLCache creates a new CRLCache. +func NewCRLCache(cfg CRLCacheConfig) (*CRLCache, error) { + switch { + case cfg.RevocationsClient == nil: + return nil, trace.BadParameter("missing RevocationsClient") + case cfg.Logger == nil: + return nil, trace.BadParameter("missing Logger") + } + return &CRLCache{ + revocationsClient: cfg.RevocationsClient, + logger: cfg.Logger, + initialized: make(chan struct{}), + }, nil +} + +// String returns a string representation of the CRLCache. 
Implements the +// tbot Service interface and fmt.Stringer interface. +func (m *CRLCache) String() string { + return "crl-cache" +} + +func (m *CRLCache) Run(ctx context.Context) error { + for { + m.logger.InfoContext( + ctx, + "Initializing cache", + ) + if err := m.watch(ctx); err != nil { + if ctx.Err() != nil { + return nil + } + // TODO(noah): DELETE IN V19 once CRL streaming functionality is + // available on all supported versions. + if trace.IsNotImplemented(err) { + m.logger.WarnContext( + ctx, "Server does not support X509 CRL functionality", + ) + // Set empty CRL set so consumers are unblocked. + m.setCRLSet(ctx, &CRLSet{}) + return nil + } + m.logger.ErrorContext( + ctx, + "Cache failed, will attempt to re-initialize after back off", + "error", err, + "backoff", trustBundleInitFailureBackoff, + ) + } + select { + case <-ctx.Done(): + return nil + case <-time.After(trustBundleInitFailureBackoff): + continue + } + } +} + +func (m *CRLCache) watch(ctx context.Context) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := m.revocationsClient.StreamSignedCRL( + ctx, &workloadidentityv1pb.StreamSignedCRLRequest{}, + ) + if err != nil { + return trace.Wrap(err, "opening CRL stream") + } + + for { + res, err := stream.Recv() + if err != nil { + return trace.Wrap(err, "receiving CRL stream data") + } + m.setCRLSet(ctx, &CRLSet{ + LocalCRL: res.Crl, + }) + } +} + +func (m *CRLCache) setCRLSet(ctx context.Context, crlSet *CRLSet) { + m.mu.Lock() + defer m.mu.Unlock() + old := m.crlSet + + // Exit early if the CRL set is the same as the current one. + if old != nil { + if bytes.Equal(old.LocalCRL, crlSet.LocalCRL) { + m.logger.DebugContext(ctx, "Ignoring unchanged CRL set") + return + } + } + + // Clone the CRL set to avoid the caller mutating the state after it has + // been set. + m.crlSet = crlSet.Clone() + m.crlSet.stale = make(chan struct{}) + + if old == nil { + // Indicate that the first CRL set is now available. 
+ close(m.initialized) + } else { + // Indicate that a new CRL set is available. + close(old.stale) + } + m.logger.DebugContext(ctx, "Broadcasting new CRL set to consumers") +} + +func (m *CRLCache) getCRLSet() *CRLSet { + m.mu.Lock() + defer m.mu.Unlock() + if m.crlSet == nil { + return nil + } + // Clone so a receiver cannot mutate the current state without calling + // setCRLSet. + return m.crlSet.Clone() +} + +// GetCRLSet returns the current CRLSet. If the cache is not yet +// initialized, it will block until it is. +func (m *CRLCache) GetCRLSet( + ctx context.Context, +) (*CRLSet, error) { + select { + case <-m.initialized: + return m.getCRLSet(), nil + case <-ctx.Done(): + return nil, ctx.Err() + } +} + +// FetchCRLSet fetches the current CRL set from the revocations service. +// Use this only in the implementation of OneShot methods, and prefer using the +// cache for long-running services. +func FetchCRLSet( + ctx context.Context, + revocationsClient workloadidentityv1pb.WorkloadIdentityRevocationServiceClient, +) (*CRLSet, error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stream, err := revocationsClient.StreamSignedCRL( + ctx, &workloadidentityv1pb.StreamSignedCRLRequest{}, + ) + if err != nil { + return nil, trace.Wrap(err, "streaming CRL") + } + + res, err := stream.Recv() + if err != nil { + if trace.IsNotImplemented(err) { + slog.WarnContext(ctx, "Server does not support X509 CRL functionality, no CRL will be included in the output.") + return &CRLSet{}, nil + } + return nil, trace.Wrap(err, "receiving CRL") + } + + return &CRLSet{ + LocalCRL: res.Crl, + }, nil +} diff --git a/tool/tctl/common/workload_identity_command.go b/tool/tctl/common/workload_identity_command.go index c87b1d95f0334..7bb9b1f686d4d 100644 --- a/tool/tctl/common/workload_identity_command.go +++ b/tool/tctl/common/workload_identity_command.go @@ -18,8 +18,10 @@ package common import ( "context" + "encoding/pem" "fmt" "io" + "log/slog" "math/big" "os" "strings" 
@@ -50,9 +52,12 @@ type WorkloadIdentityCommand struct { listCmd *kingpin.CmdClause rmCmd *kingpin.CmdClause - revocationsAddCmd *kingpin.CmdClause - revocationsRmCmd *kingpin.CmdClause - revocationsLsCmd *kingpin.CmdClause + revocationsAddCmd *kingpin.CmdClause + revocationsRmCmd *kingpin.CmdClause + revocationsLsCmd *kingpin.CmdClause + revocationsCrlCmd *kingpin.CmdClause + revocationsCRLFollow bool + revocationsCRLOut string revocationType string revocationSerial string @@ -126,6 +131,16 @@ func (c *WorkloadIdentityCommand) Initialize( Default(teleport.Text). EnumVar(&c.format, teleport.Text, teleport.JSON) + c.revocationsCrlCmd = revocationsCmd.Command( + "crl", "Fetch the signed CRL for existing revocations.", + ) + c.revocationsCrlCmd.Flag( + "follow", "Follow the stream of CRL updates.", + ).BoolVar(&c.revocationsCRLFollow) + c.revocationsCrlCmd.Flag( + "out", "Path to write the CRL as a file to. If unspecified, STDOUT will be used.", + ).StringVar(&c.revocationsCRLOut) + if c.stdout == nil { c.stdout = os.Stdout } @@ -150,6 +165,8 @@ func (c *WorkloadIdentityCommand) TryRun( commandFunc = c.ListRevocations case c.revocationsRmCmd.FullCommand(): commandFunc = c.DeleteRevocation + case c.revocationsCrlCmd.FullCommand(): + commandFunc = c.StreamCRL default: return false, nil } @@ -378,3 +395,55 @@ func (c *WorkloadIdentityCommand) ListRevocations( } return nil } + +func (c *WorkloadIdentityCommand) StreamCRL( + ctx context.Context, client *authclient.Client, +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + revocationsClient := client.WorkloadIdentityRevocationServiceClient() + + req := &workloadidentityv1pb.StreamSignedCRLRequest{} + stream, err := revocationsClient.StreamSignedCRL(ctx, req) + if err != nil { + return trace.Wrap(err) + } + + write := func(data []byte) error { + _, err := c.stdout.Write(data) + return trace.Wrap(err) + } + if c.revocationsCRLOut != "" { + write = func(data []byte) error { + err := 
os.WriteFile(c.revocationsCRLOut, data, 0644) + if err != nil { + return trace.Wrap(err) + } + slog.InfoContext(ctx, "Successfully wrote updated CRL", "path", c.revocationsCRLOut) + return nil + } + } + + for { + res, err := stream.Recv() + if err != nil { + if trace.IsNotImplemented(err) { + slog.ErrorContext(ctx, "Server does not support X509 CRL functionality") + } + return trace.Wrap(err) + } + slog.InfoContext(ctx, "Received CRL from server") + pemData := pem.EncodeToMemory(&pem.Block{ + Type: "X509 CRL", + Bytes: res.Crl, + }) + if err := write(pemData); err != nil { + return trace.Wrap(err, "writing CRL pem") + } + + // If --follow has not been specified, exit. + if !c.revocationsCRLFollow { + return nil + } + } +}