Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/Core/Vault/Repositories/ISecurityTaskRepository.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ public interface ISecurityTaskRepository : IRepository<SecurityTask, Guid>
/// <param name="organizationId">The id of the organization</param>
/// <returns>A collection of security task metrics</returns>
Task<SecurityTaskMetrics> GetTaskMetricsAsync(Guid organizationId);

/// <summary>
/// Marks all tasks associated with the respective ciphers as complete.
/// </summary>
/// <param name="cipherIds">Collection of cipher IDs</param>
Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds);
}
6 changes: 6 additions & 0 deletions src/Core/Vault/Services/Implementations/CipherService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class CipherService : ICipherService
private readonly IOrganizationRepository _organizationRepository;
private readonly IOrganizationUserRepository _organizationUserRepository;
private readonly ICollectionCipherRepository _collectionCipherRepository;
private readonly ISecurityTaskRepository _securityTaskRepository;
private readonly IPushNotificationService _pushService;
private readonly IAttachmentStorageService _attachmentStorageService;
private readonly IEventService _eventService;
Expand All @@ -53,6 +54,7 @@ public CipherService(
IOrganizationRepository organizationRepository,
IOrganizationUserRepository organizationUserRepository,
ICollectionCipherRepository collectionCipherRepository,
ISecurityTaskRepository securityTaskRepository,
IPushNotificationService pushService,
IAttachmentStorageService attachmentStorageService,
IEventService eventService,
Expand All @@ -71,6 +73,7 @@ public CipherService(
_organizationRepository = organizationRepository;
_organizationUserRepository = organizationUserRepository;
_collectionCipherRepository = collectionCipherRepository;
_securityTaskRepository = securityTaskRepository;
_pushService = pushService;
_attachmentStorageService = attachmentStorageService;
_eventService = eventService;
Expand Down Expand Up @@ -724,6 +727,7 @@ public async Task SoftDeleteAsync(CipherDetails cipherDetails, Guid deletingUser
cipherDetails.ArchivedDate = null;
}

await _securityTaskRepository.MarkAsCompleteByCipherIds([cipherDetails.Id]);
await _cipherRepository.UpsertAsync(cipherDetails);
await _eventService.LogCipherEventAsync(cipherDetails, EventType.Cipher_SoftDeleted);

Expand All @@ -750,6 +754,8 @@ public async Task SoftDeleteManyAsync(IEnumerable<Guid> cipherIds, Guid deleting
await _cipherRepository.SoftDeleteAsync(deletingCiphers.Select(c => c.Id), deletingUserId);
}

await _securityTaskRepository.MarkAsCompleteByCipherIds(deletingCiphers.Select(c => c.Id));

var events = deletingCiphers.Select(c =>
new Tuple<Cipher, EventType, DateTime?>(c, EventType.Cipher_SoftDeleted, null));
foreach (var eventsBatch in events.Chunk(100))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,19 @@ await connection.ExecuteAsync(

return tasksList;
}

/// <inheritdoc />
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
{
if (!cipherIds.Any())
{
return;
}

await using var connection = new SqlConnection(ConnectionString);
await connection.ExecuteAsync(
$"[{Schema}].[SecurityTask_MarkCompleteByCipherIds]",
new { CipherIds = cipherIds.ToGuidIdArrayTVP() },
commandType: CommandType.StoredProcedure);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,4 +96,24 @@ join o in dbContext.Organizations on st.OrganizationId equals o.Id

return metrics ?? new Core.Vault.Entities.SecurityTaskMetrics(0, 0);
}

/// <inheritdoc />
public async Task MarkAsCompleteByCipherIds(IEnumerable<Guid> cipherIds)
{
if (!cipherIds.Any())
{
return;
}

using var scope = ServiceScopeFactory.CreateScope();
var dbContext = GetDatabaseContext(scope);

var cipherIdsList = cipherIds.ToList();

await dbContext.SecurityTasks
.Where(st => st.CipherId.HasValue && cipherIdsList.Contains(st.CipherId.Value) && st.Status != SecurityTaskStatus.Completed)
.ExecuteUpdateAsync(st => st
.SetProperty(s => s.Status, SecurityTaskStatus.Completed)
.SetProperty(s => s.RevisionDate, DateTime.UtcNow));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
CREATE PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
@CipherIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON

UPDATE
[dbo].[SecurityTask]
SET
[Status] = 1, -- completed
[RevisionDate] = SYSUTCDATETIME()
WHERE
[CipherId] IN (SELECT [Id] FROM @CipherIds)
AND [Status] <> 1 -- Not already completed
END
57 changes: 57 additions & 0 deletions test/Core.Test/Vault/Services/CipherServiceTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2286,6 +2286,63 @@ await sutProvider.GetDependency<IPushNotificationService>()
.PushSyncCiphersAsync(deletingUserId);
}

[Theory]
[BitAutoData]
public async Task SoftDeleteAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, CipherDetails cipherDetails, SutProvider<CipherService> sutProvider)
{
cipherDetails.UserId = deletingUserId;
cipherDetails.OrganizationId = null;
cipherDetails.DeletedDate = null;

sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});

await sutProvider.Sut.SoftDeleteAsync(cipherDetails, deletingUserId);

await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == 1 && ids.First() == cipherDetails.Id));
}

[Theory]
[BitAutoData]
public async Task SoftDeleteManyAsync_CallsMarkAsCompleteByCipherIds(
Guid deletingUserId, List<CipherDetails> ciphers, SutProvider<CipherService> sutProvider)
{
var cipherIds = ciphers.Select(c => c.Id).ToArray();

foreach (var cipher in ciphers)
{
cipher.UserId = deletingUserId;
cipher.OrganizationId = null;
cipher.Edit = true;
cipher.DeletedDate = null;
}

sutProvider.GetDependency<IUserService>()
.GetUserByIdAsync(deletingUserId)
.Returns(new User
{
Id = deletingUserId,
});
sutProvider.GetDependency<ICipherRepository>()
.GetManyByUserIdAsync(deletingUserId)
.Returns(ciphers);

await sutProvider.Sut.SoftDeleteManyAsync(cipherIds, deletingUserId, null, false);

await sutProvider.GetDependency<ISecurityTaskRepository>()
.Received(1)
.MarkAsCompleteByCipherIds(Arg.Is<IEnumerable<Guid>>(ids =>
ids.Count() == cipherIds.Length && ids.All(id => cipherIds.Contains(id))));
}

private async Task AssertNoActionsAsync(SutProvider<CipherService> sutProvider)
{
await sutProvider.GetDependency<ICipherRepository>().DidNotReceiveWithAnyArgs().GetManyOrganizationDetailsByOrganizationIdAsync(default);
Expand Down
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

πŸ§ͺ πŸŽ‰ Thanks!

Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,110 @@ public async Task GetZeroTaskMetricsAsync(
Assert.Equal(0, metrics.CompletedTasks);
Assert.Equal(0, metrics.TotalTasks);
}

[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_MarksPendingTasksAsCompleted(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "[email protected]"
});

var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});

var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});

var task1 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});

var task2 = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});

await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id, cipher2.Id]);

var updatedTask1 = await securityTaskRepository.GetByIdAsync(task1.Id);
var updatedTask2 = await securityTaskRepository.GetByIdAsync(task2.Id);

Assert.Equal(SecurityTaskStatus.Completed, updatedTask1.Status);
Assert.Equal(SecurityTaskStatus.Completed, updatedTask2.Status);
}

[DatabaseTheory, DatabaseData]
public async Task MarkAsCompleteByCipherIds_OnlyUpdatesSpecifiedCiphers(
IOrganizationRepository organizationRepository,
ICipherRepository cipherRepository,
ISecurityTaskRepository securityTaskRepository)
{
var organization = await organizationRepository.CreateAsync(new Organization
{
Name = "Test Org",
PlanType = PlanType.EnterpriseAnnually,
Plan = "Test Plan",
BillingEmail = "[email protected]"
});

var cipher1 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});

var cipher2 = await cipherRepository.CreateAsync(new Cipher
{
Type = CipherType.Login,
OrganizationId = organization.Id,
Data = "",
});

var taskToUpdate = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher1.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});

var taskToKeep = await securityTaskRepository.CreateAsync(new SecurityTask
{
OrganizationId = organization.Id,
CipherId = cipher2.Id,
Status = SecurityTaskStatus.Pending,
Type = SecurityTaskType.UpdateAtRiskCredential,
});

await securityTaskRepository.MarkAsCompleteByCipherIds([cipher1.Id]);

var updatedTask = await securityTaskRepository.GetByIdAsync(taskToUpdate.Id);
var unchangedTask = await securityTaskRepository.GetByIdAsync(taskToKeep.Id);

Assert.Equal(SecurityTaskStatus.Completed, updatedTask.Status);
Assert.Equal(SecurityTaskStatus.Pending, unchangedTask.Status);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
CREATE OR ALTER PROCEDURE [dbo].[SecurityTask_MarkCompleteByCipherIds]
@CipherIds AS [dbo].[GuidIdArray] READONLY
AS
BEGIN
SET NOCOUNT ON

UPDATE
[dbo].[SecurityTask]
SET
[Status] = 1, -- Completed
[RevisionDate] = SYSUTCDATETIME()
WHERE
[CipherId] IN (SELECT [Id] FROM @CipherIds)
AND [Status] <> 1 -- Not already completed
END
Loading