diff --git a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs index 4b88f1c0e840..0be3bbd5456c 100644 --- a/src/Core/Vault/Repositories/ISecurityTaskRepository.cs +++ b/src/Core/Vault/Repositories/ISecurityTaskRepository.cs @@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository /// The id of the organization /// A collection of security task metrics Task GetTaskMetricsAsync(Guid organizationId); + + /// + /// Marks all tasks associated with the respective ciphers as complete. + /// + /// Collection of cipher IDs + Task MarkAsCompleteByCipherIds(IEnumerable cipherIds); } diff --git a/src/Core/Vault/Services/Implementations/CipherService.cs b/src/Core/Vault/Services/Implementations/CipherService.cs index db458a523d0e..4e980f66b6ce 100644 --- a/src/Core/Vault/Services/Implementations/CipherService.cs +++ b/src/Core/Vault/Services/Implementations/CipherService.cs @@ -33,6 +33,7 @@ public class CipherService : ICipherService private readonly IOrganizationRepository _organizationRepository; private readonly IOrganizationUserRepository _organizationUserRepository; private readonly ICollectionCipherRepository _collectionCipherRepository; + private readonly ISecurityTaskRepository _securityTaskRepository; private readonly IPushNotificationService _pushService; private readonly IAttachmentStorageService _attachmentStorageService; private readonly IEventService _eventService; @@ -53,6 +54,7 @@ public CipherService( IOrganizationRepository organizationRepository, IOrganizationUserRepository organizationUserRepository, ICollectionCipherRepository collectionCipherRepository, + ISecurityTaskRepository securityTaskRepository, IPushNotificationService pushService, IAttachmentStorageService attachmentStorageService, IEventService eventService, @@ -71,6 +73,7 @@ public CipherService( _organizationRepository = organizationRepository; _organizationUserRepository = organizationUserRepository; _collectionCipherRepository = collectionCipherRepository; + _securityTaskRepository = securityTaskRepository; _pushService = pushService; _attachmentStorageService = attachmentStorageService; _eventService = eventService; @@ -724,6 +727,7 @@ public async Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUser cipherDetails.ArchivedDate = null; } + await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]); await _cipherRepository.UpsertAsync(cipherDetails); await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted); @@ -750,6 +754,8 @@ public async Task SoftDeleteManyAsync(IEnumerable cipherIds, Guid deleting await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId); } + await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id)); + var events = deletingCiphers.Select(c => new Tuple(c, EventType.Cipher_SoftDeleted, null)); foreach (var eventsBatch in events.Chunk(100)) diff --git a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs index 292e99d6ad73..869321f2800d 100644 --- a/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs +++ b/src/Infrastructure.Dapper/Vault/Repositories/SecurityTaskRepository.cs @@ -85,4 +85,19 @@ await connection.ExecuteAsync( return tasksList; } + + /// + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + await using var connection = new SqlConnection(ConnectionString); + await connection.ExecuteAsync( + $"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]", + new { CipherIds = cipherIds.ToGuidIdArrayTVP() }, + commandType: CommandType.StoredProcedure); + } } diff --git a/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs b/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs index d4f9424d40a4..9967f18a3e2c 100644 --- a/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs +++ b/src/Infrastructure.EntityFramework/Vault/Repositories/SecurityTaskRepository.cs @@ -96,4 +96,24 @@ join o in dbContext.Organizations on st.OrganizationId equals o.Id return metrics ?? new Core.Vault.Entities.SecurityTaskMetrics(0, 0); } + + /// + public async Task MarkAsCompleteByCipherIds(IEnumerable cipherIds) + { + if (!cipherIds.Any()) + { + return; + } + + using var scope = ServiceScopeFactory.CreateScope(); + var dbContext = GetDatabaseContext(scope); + + var cipherIdsList = cipherIds.ToList(); + + await dbContext.SecurityTasks + .Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed) + .ExecuteUpdateAsync(st => st + .SetProperty(s => s.Status, SecurityTaskStatus.Completed) + .SetProperty(s => s.RevisionDate, DateTime.UtcNow)); + } } diff --git a/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql new file mode 100644 index 000000000000..8e00d06e4365 --- /dev/null +++ b/src/Sql/dbo/Vault/Stored Procedures/SecurityTask/SecurityTask_MarkCompleteByCipherIds.sql @@ -0,0 +1,15 @@ +CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds] + @CipherIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[SecurityTask] + SET + [Status] = 1, -- completed + [RevisionDate] = SYSUTCDATETIME() + WHERE + [CipherId] IN (SELECT [Id] FROM @CipherIds) + AND [Status] <> 1 -- Not already completed +END diff --git a/test/Core.Test/Vault/Services/CipherServiceTests.cs b/test/Core.Test/Vault/Services/CipherServiceTests.cs index 95391f1f44df..fb53c41bad6c 100644 --- a/test/Core.Test/Vault/Services/CipherServiceTests.cs +++ b/test/Core.Test/Vault/Services/CipherServiceTests.cs @@ -2286,6 +2286,63 @@ await sutProvider.GetDependency() .PushSyncCiphersAsync(deletingUserId); } + [Theory] + [BitAutoData] + public async Task SoftDeleteAsync_CallsMarkAsCompleteByCipherIds( + Guid deletingUserId, CipherDetails cipherDetails, SutProvider sutProvider) + { + cipherDetails.UserId = deletingUserId; + cipherDetails.OrganizationId = null; + cipherDetails.DeletedDate = null; + + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(new User + { + Id = deletingUserId, + }); + + await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId); + + await sutProvider.GetDependency() + .Received(1) + .MarkAsCompleteByCipherIds(Arg.Is>(ids => + ids.Count() == 1 && ids.First() == cipherDetails.Id)); + } + + [Theory] + [BitAutoData] + public async Task SoftDeleteManyAsync_CallsMarkAsCompleteByCipherIds( + Guid deletingUserId, List ciphers, SutProvider sutProvider) + { + var cipherIds = ciphers.Select(c => c.Id).ToArray(); + + foreach (var cipher in ciphers) + { + cipher.UserId = deletingUserId; + cipher.OrganizationId = null; + cipher.Edit = true; + cipher.DeletedDate = null; + } + + sutProvider.GetDependency() + .GetUserByIdAsync(deletingUserId) + .Returns(new User + { + Id = deletingUserId, + }); + sutProvider.GetDependency() + .GetManyByUserIdAsync(deletingUserId) + .Returns(ciphers); + + await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false); + + await sutProvider.GetDependency() + .Received(1) + .MarkAsCompleteByCipherIds(Arg.Is>(ids => + ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id)))); + } + private async Task AssertNoActionsAsync(SutProvider sutProvider) { await sutProvider.GetDependency().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default); diff --git a/test/Infrastructure.IntegrationTest/Vault/Repositories/SecurityTaskRepositoryTests.cs b/test/Infrastructure.IntegrationTest/Vault/Repositories/SecurityTaskRepositoryTests.cs index f17950c04de3..68c1be69f673 100644 --- a/test/Infrastructure.IntegrationTest/Vault/Repositories/SecurityTaskRepositoryTests.cs +++ b/test/Infrastructure.IntegrationTest/Vault/Repositories/SecurityTaskRepositoryTests.cs @@ -345,4 +345,110 @@ public async Task GetZeroTaskMetricsAsync( Assert.Equal(0, metrics.CompletedTasks); Assert.Equal(0, metrics.TotalTasks); } + + [DatabaseTheory, DatabaseData] + public async Task MarkAsCompleteByCipherIds_MarksPendingTasksAsCompleted( + IOrganizationRepository organizationRepository, + ICipherRepository cipherRepository, + ISecurityTaskRepository securityTaskRepository) + { + var organization = await organizationRepository.CreateAsync(new Organization + { + Name = "Test Org", + PlanType = PlanType.EnterpriseAnnually, + Plan = "Test Plan", + BillingEmail = "billing@email.com" + }); + + var cipher1 = await cipherRepository.CreateAsync(new Cipher + { + Type = CipherType.Login, + OrganizationId = organization.Id, + Data = "", + }); + + var cipher2 = await cipherRepository.CreateAsync(new Cipher + { + Type = CipherType.Login, + OrganizationId = organization.Id, + Data = "", + }); + + var task1 = await securityTaskRepository.CreateAsync(new SecurityTask + { + OrganizationId = organization.Id, + CipherId = cipher1.Id, + Status = SecurityTaskStatus.Pending, + Type = SecurityTaskType.UpdateAtRiskCredential, + }); + + var task2 = await securityTaskRepository.CreateAsync(new SecurityTask + { + OrganizationId = organization.Id, + CipherId = cipher2.Id, + Status = SecurityTaskStatus.Pending, + Type = SecurityTaskType.UpdateAtRiskCredential, + }); + + await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id, cipher2.Id]); + + var updatedTask1 = await securityTaskRepository.GetByIdAsync(task1.Id); + var updatedTask2 = await securityTaskRepository.GetByIdAsync(task2.Id); + + Assert.Equal(SecurityTaskStatus.Completed, updatedTask1.Status); + Assert.Equal(SecurityTaskStatus.Completed, updatedTask2.Status); + } + + [DatabaseTheory, DatabaseData] + public async Task MarkAsCompleteByCipherIds_OnlyUpdatesSpecifiedCiphers( + IOrganizationRepository organizationRepository, + ICipherRepository cipherRepository, + ISecurityTaskRepository securityTaskRepository) + { + var organization = await organizationRepository.CreateAsync(new Organization + { + Name = "Test Org", + PlanType = PlanType.EnterpriseAnnually, + Plan = "Test Plan", + BillingEmail = "billing@email.com" + }); + + var cipher1 = await cipherRepository.CreateAsync(new Cipher + { + Type = CipherType.Login, + OrganizationId = organization.Id, + Data = "", + }); + + var cipher2 = await cipherRepository.CreateAsync(new Cipher + { + Type = CipherType.Login, + OrganizationId = organization.Id, + Data = "", + }); + + var taskToUpdate = await securityTaskRepository.CreateAsync(new SecurityTask + { + OrganizationId = organization.Id, + CipherId = cipher1.Id, + Status = SecurityTaskStatus.Pending, + Type = SecurityTaskType.UpdateAtRiskCredential, + }); + + var taskToKeep = await securityTaskRepository.CreateAsync(new SecurityTask + { + OrganizationId = organization.Id, + CipherId = cipher2.Id, + Status = SecurityTaskStatus.Pending, + Type = SecurityTaskType.UpdateAtRiskCredential, + }); + + await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id]); + + var updatedTask = await securityTaskRepository.GetByIdAsync(taskToUpdate.Id); + var unchangedTask = await securityTaskRepository.GetByIdAsync(taskToKeep.Id); + + Assert.Equal(SecurityTaskStatus.Completed, updatedTask.Status); + Assert.Equal(SecurityTaskStatus.Pending, unchangedTask.Status); + } } diff --git a/util/Migrator/DbScripts/2025-10-23_00_CompleteSecurityTaskByCipherIds.sql b/util/Migrator/DbScripts/2025-10-23_00_CompleteSecurityTaskByCipherIds.sql new file mode 100644 index 000000000000..e465b8470aaa --- /dev/null +++ b/util/Migrator/DbScripts/2025-10-23_00_CompleteSecurityTaskByCipherIds.sql @@ -0,0 +1,15 @@ +CREATE OR ALTER PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds] + @CipherIds AS [dbo].[GuidIdArray] READONLY +AS +BEGIN + SET NOCOUNT ON + + UPDATE + [dbo].[SecurityTask] + SET + [Status] = 1, -- Completed + [RevisionDate] = SYSUTCDATETIME() + WHERE + [CipherId] IN (SELECT [Id] FROM @CipherIds) + AND [Status] <> 1 -- Not already completed +END