diff --git a/src/MalwareSampleExchange.Console/Database/MongoMetadataHandler.cs b/src/MalwareSampleExchange.Console/Database/MongoMetadataHandler.cs index bbea6bd..e3627be 100644 --- a/src/MalwareSampleExchange.Console/Database/MongoMetadataHandler.cs +++ b/src/MalwareSampleExchange.Console/Database/MongoMetadataHandler.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Hosting; @@ -19,7 +20,7 @@ public interface ISampleMetadataHandler /// /// /// - Task> GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, CancellationToken token = default); + IAsyncEnumerable GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, CancellationToken token = default); Task InsertSampleAsync(RequestExportSample sample, CancellationToken token = default); } @@ -35,18 +36,25 @@ public MongoMetadataHandler(IOptions options) _mongoClient = new MongoClient(options.Value.ConnectionString); } - public async Task> GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, CancellationToken token = default) + public async IAsyncEnumerable GetSamplesAsync(DateTime start, DateTime? end, string? sampleSet, [EnumeratorCancellation] CancellationToken token = default) { var mongoDatabase = _mongoClient.GetDatabase(_options.DatabaseName); var sampleCollection = mongoDatabase.GetCollection(_options.CollectionName); var list = end == null ? await sampleCollection - .FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start, cancellationToken: token) + .FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start && sample.DoNotUseBefore <= DateTime.Now, cancellationToken: token) : await sampleCollection - .FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start && sample.Imported <= end, cancellationToken: token); - return list.ToList(cancellationToken: token); - } + .FindAsync(sample => sample.SampleSet == sampleSet && sample.Imported >= start && sample.Imported <= end && sample.DoNotUseBefore <= DateTime.Now, cancellationToken: token); + while (await list.MoveNextAsync(token)) + { + foreach (var current in list.Current) + { + yield return current; + } + } + } + public async Task InsertSampleAsync(RequestExportSample sample, CancellationToken token = default) { var mongoDatabase = _mongoClient.GetDatabase(_options.DatabaseName); diff --git a/src/MalwareSampleExchange.Console/ListRequester/ListRequester.cs b/src/MalwareSampleExchange.Console/ListRequester/ListRequester.cs index ef3b168..47380ac 100644 --- a/src/MalwareSampleExchange.Console/ListRequester/ListRequester.cs +++ b/src/MalwareSampleExchange.Console/ListRequester/ListRequester.cs @@ -43,10 +43,9 @@ public async IAsyncEnumerable RequestList(string username, DateTime start { var includeFamilyName = _partnerProvider.Partners.Single(_ => _.Name == username).IncludeFamilyName; var sampleSet = _partnerProvider.Partners.SingleOrDefault(_ => _.Name == username)?.Sampleset; + var samples = _sampleMetadataHandler.GetSamplesAsync(start, end, sampleSet, token); - var samples = await _sampleMetadataHandler.GetSamplesAsync(start, end, sampleSet, token); - - foreach (var sample in samples.Where(sample => sample.DoNotUseBefore <= DateTime.Now)) + await foreach (var sample in samples) { var fileSize = sample.FileSize == 0 ? await _sampleStorageHandler.GetFileSizeAsync(sample.Sha256, token)