Skip to content
Merged
13 changes: 7 additions & 6 deletions src/Microsoft.ML.Sweeper/AsyncSweeper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using System.Threading.Tasks.Dataflow;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Internal.Utilities;
Expand Down Expand Up @@ -168,7 +168,7 @@ public sealed class Options
private readonly object _lock;
private readonly CancellationTokenSource _cts;

private readonly BufferBlock<ParameterSetWithId> _paramQueue;
private readonly Channel<ParameterSetWithId> _paramChannel;
private readonly int _relaxation;
private readonly ISweeper _baseSweeper;
private readonly IHost _host;
Expand Down Expand Up @@ -208,7 +208,8 @@ public DeterministicSweeperAsync(IHostEnvironment env, Options options)
_lock = new object();
_results = new List<IRunResult>();
_nullRuns = new HashSet<int>();
_paramQueue = new BufferBlock<ParameterSetWithId>();
_paramChannel = Channel.CreateUnbounded<ParameterSetWithId>(
new UnboundedChannelOptions { SingleWriter = true });

PrepareNextBatch(null);
}
Expand All @@ -220,12 +221,12 @@ private void PrepareNextBatch(IEnumerable<IRunResult> results)
if (Utils.Size(paramSets) == 0)
{
// Mark the queue as completed.
_paramQueue.Complete();
_paramChannel.Writer.Complete();
return;
}
// Assign an id to each ParameterSet and enque it.
foreach (var paramSet in paramSets)
_paramQueue.Post(new ParameterSetWithId(_numGenerated++, paramSet));
_paramChannel.Writer.TryWrite(new ParameterSetWithId(_numGenerated++, paramSet));
EnsureResultsSize();
}

Expand Down Expand Up @@ -278,7 +279,7 @@ public async Task<ParameterSetWithId> ProposeAsync()
return null;
try
{
return await _paramQueue.ReceiveAsync(_cts.Token);
return await _paramChannel.Reader.ReadAsync(_cts.Token);
}
catch (InvalidOperationException)
{
Expand Down
4 changes: 4 additions & 0 deletions src/Microsoft.ML.Sweeper/Microsoft.ML.Sweeper.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="System.Threading.Channels" Version="4.7.1" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.CpuMath\Microsoft.ML.CpuMath.csproj" />
Expand Down