Skip to content

Commit

Permalink
Add cancellation capability for commands. Also change Command.Wait() …
Browse files Browse the repository at this point in the history
…and Command.Result to not throw AggregateException. Fix #18
  • Loading branch information
madelson committed Jul 1, 2017
1 parent 3520b47 commit e6219fc
Show file tree
Hide file tree
Showing 8 changed files with 240 additions and 51 deletions.
97 changes: 95 additions & 2 deletions MedallionShell.Tests/GeneralTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,107 @@ public void TestTimeout()
Assert.IsInstanceOfType(ex.InnerException, typeof(TimeoutException));
}

[TestMethod]
public void TestZeroTimeout()
{
var willTimeout = Command.Run("SampleCommand", new object[] { "sleep", 1000000 }, o => o.Timeout(TimeSpan.Zero));
var ex = UnitTestHelpers.AssertThrows<AggregateException>(() => willTimeout.Task.Wait());
Assert.IsInstanceOfType(ex.InnerException, typeof(TimeoutException));
}

[TestMethod]
public void TestCancellationAlreadyCanceled()
{
using (var alreadyCanceled = new CancellationTokenSource(millisecondsDelay: 0))
{
var command = Command.Run("SampleCommand", new object[] { "sleep", 1000000 }, o => o.CancellationToken(alreadyCanceled.Token));
UnitTestHelpers.AssertThrows<TaskCanceledException>(() => command.Wait());
UnitTestHelpers.AssertThrows<TaskCanceledException>(() => command.Result.ToString());
command.Task.Status.ShouldEqual(TaskStatus.Canceled);
UnitTestHelpers.AssertDoesNotThrow(() => command.ProcessId.ToString(), "still executes a command and gets a process ID");
}
}

[TestMethod]
public void TestCancellationNotCanceled()
{
using (var notCanceled = new CancellationTokenSource())
{
var command = Command.Run("SampleCommand", new object[] { "sleep", 1000000 }, o => o.CancellationToken(notCanceled.Token));
command.Task.Wait(50).ShouldEqual(false);
command.Kill();
command.Task.Wait(1000).ShouldEqual(true);
command.Result.Success.ShouldEqual(false);
}
}

[TestMethod]
public void TestCancellationCanceledPartway()
{
using (var cancellationTokenSource = new CancellationTokenSource())
{
var results = new SynchronizedCollection<string>();
var command = Command.Run("SampleCommand", new object[] { "echo", "--per-char" }, o => o.CancellationToken(cancellationTokenSource.Token)) > results;
command.StandardInput.WriteLine("hello");
var timeout = Task.Delay(TimeSpan.FromSeconds(10));
while (results.Count == 0 && !timeout.IsCompleted) ;
results.Count.ShouldEqual(1);
cancellationTokenSource.Cancel();
var aggregateException = UnitTestHelpers.AssertThrows<AggregateException>(() => command.Task.Wait(1000));
UnitTestHelpers.AssertIsInstanceOf<TaskCanceledException>(aggregateException.GetBaseException());
CollectionAssert.AreEqual(results, new[] { "hello" });
}
}

[TestMethod]
public void TestCancellationCanceledAfterCompletion()
{
using (var cancellationTokenSource = new CancellationTokenSource())
{
var results = new List<string>();
var command = Command.Run("SampleCommand", new object[] { "echo" }, o => o.CancellationToken(cancellationTokenSource.Token)) > results;
command.StandardInput.WriteLine("hello");
command.StandardInput.Close();
command.Task.Wait(1000).ShouldEqual(true);
cancellationTokenSource.Cancel();
command.Result.Success.ShouldEqual(true);
}
}

[TestMethod]
public void TestCancellationWithTimeoutTimeoutWins()
{
var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(5));
var command = Command.Run(
"SampleCommand",
new object[] { "sleep", 1000000 },
o => o.CancellationToken(cancellationTokenSource.Token)
.Timeout(TimeSpan.FromMilliseconds(50))
);
UnitTestHelpers.AssertThrows<TimeoutException>(() => command.Wait());
}

[TestMethod]
public void TestCancellationWithTimeoutCancellationWins()
{
var cancellationTokenSource = new CancellationTokenSource(TimeSpan.FromMilliseconds(50));
var command = Command.Run(
"SampleCommand",
new object[] { "sleep", 1000000 },
o => o.CancellationToken(cancellationTokenSource.Token)
.Timeout(TimeSpan.FromSeconds(5))
);
UnitTestHelpers.AssertThrows<TaskCanceledException>(() => command.Wait());
}

[TestMethod]
public void TestErrorHandling()
{
var command = Command.Run("SampleCommand", "echo") < "abc" > new char[0];
UnitTestHelpers.AssertThrows<AggregateException>(() => command.Wait());
UnitTestHelpers.AssertThrows<NotSupportedException>(() => command.Wait());

var command2 = Command.Run("SampleCommand", "echo") < this.ErrorLines();
UnitTestHelpers.AssertThrows<AggregateException>(() => command.Wait());
UnitTestHelpers.AssertThrows<InvalidOperationException>(() => command2.Wait());
}

[TestMethod]
Expand Down
1 change: 1 addition & 0 deletions MedallionShell.Tests/MedallionShell.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
<Reference Include="System.Management" />
<Reference Include="System.Net.Http" />
<Reference Include="System.Numerics" />
<Reference Include="System.ServiceModel" />
<Reference Include="System.Xml" />
<Reference Include="System.Xml.Linq" />
</ItemGroup>
Expand Down
9 changes: 9 additions & 0 deletions MedallionShell.Tests/UnitTestHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ public static TException AssertThrows<TException>(Action action, string message
throw new InvalidOperationException("Should never get here");
}

public static void AssertDoesNotThrow(Action action, string message = null)
{
try { action(); }
catch (Exception ex)
{
Assert.Fail($"Expected: no failure; was: '{ex}'{(message != null ? message + ": " : string.Empty)}");
}
}

public static void AssertIsInstanceOf<T>(object value, string message = null)
{
Assert.IsInstanceOfType(value, typeof(T), message);
Expand Down
71 changes: 71 additions & 0 deletions MedallionShell/CancellationOrTimeout.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Medallion.Shell
{
/// <summary>
/// Provides a <see cref="Task"/> which combines the logic of watching for cancellation or a timeout
/// </summary>
internal sealed class CancellationOrTimeout : IDisposable
{
private List<IDisposable> resources;

private CancellationOrTimeout(Task task, List<IDisposable> resources)
{
this.Task = task;
this.resources = resources;
}

public static CancellationOrTimeout TryCreate(CancellationToken cancellationToken, TimeSpan timeout)
{
var hasCancellation = cancellationToken.CanBeCanceled;
var hasTimeout = timeout != Timeout.InfiniteTimeSpan;

if (!hasCancellation && !hasTimeout)
{
// originally, I designed this to return a static task which never completes. However, this can cause
// memory leaks from the continuations that build up on the task
return null;
}

var resources = new List<IDisposable>();
var taskBuilder = new TaskCompletionSource<bool>();

if (hasCancellation)
{
resources.Add(cancellationToken.Register(
state => ((TaskCompletionSource<bool>)state).TrySetCanceled(),
state: taskBuilder
));
}

if (hasTimeout)
{
var timeoutSource = new CancellationTokenSource(timeout);
resources.Add(timeoutSource);

resources.Add(timeoutSource.Token.Register(
state =>
{
var tupleState = (Tuple<TaskCompletionSource<bool>, TimeSpan>)state;
tupleState.Item1.TrySetException(new TimeoutException("Process killed after exceeding timeout of " + tupleState.Item2));
},
state: Tuple.Create(taskBuilder, timeout)
));
}

return new CancellationOrTimeout(taskBuilder.Task, resources);
}

public Task Task { get; }

public void Dispose()
{
this.resources?.ForEach(d => d.Dispose());
this.resources = null;
}
}
}
15 changes: 8 additions & 7 deletions MedallionShell/Command.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,18 @@ internal Command() { }
public abstract void Kill();

/// <summary>
/// A convenience method for <code>command.Task.Wait()</code>
/// A convenience method for <code>command.Task.Wait()</code>. If the task faulted or was canceled,
/// this will throw the faulting <see cref="Exception"/> or <see cref="TaskCanceledException"/> rather than
/// the wrapped <see cref="AggregateException"/> thrown by <see cref="Task{TResult}.Result"/>
/// </summary>
public void Wait()
{
this.Task.Wait();
}
public void Wait() => this.Task.GetResultWithUnwrappedException();

/// <summary>
/// A convenience method for <code>command.Task.Result</code>
/// A convenience method for <code>command.Task.Result</code>. If the task faulted or was canceled,
/// this will throw the faulting <see cref="Exception"/> or <see cref="TaskCanceledException"/> rather than
/// the wrapped <see cref="AggregateException"/> thrown by <see cref="Task{TResult}.Result"/>
/// </summary>
public CommandResult Result { get { return this.Task.Result; } }
public CommandResult Result => this.Task.GetResultWithUnwrappedException();

/// <summary>
/// A <see cref="Task"/> representing the progress of this <see cref="Command"/>
Expand Down
2 changes: 1 addition & 1 deletion MedallionShell/MedallionShell.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
<PropertyGroup>
<TreatWarningsAsErrors>True</TreatWarningsAsErrors>
<TreatSpecificWarningsAsErrors />
<WarningLevel>1</WarningLevel>
<WarningLevel>4</WarningLevel>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<Optimize>True</Optimize>
<AssemblyVersion>1.3.0.0</AssemblyVersion>
Expand Down
75 changes: 38 additions & 37 deletions MedallionShell/ProcessCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,13 @@ internal ProcessCommand(
bool throwOnError,
bool disposeOnExit,
TimeSpan timeout,
CancellationToken cancellationToken,
Encoding standardInputEncoding)
{
this.disposeOnExit = disposeOnExit;
this.process = new Process { StartInfo = startInfo, EnableRaisingEvents = true };

var processTask = CreateProcessTask(this.process, throwOnError: throwOnError, timeout: timeout);
var processTask = CreateProcessTask(this.process, throwOnError: throwOnError);

this.process.Start();

Expand Down Expand Up @@ -60,17 +61,25 @@ internal ProcessCommand(
{
this.processIdOrExceptionDispatchInfo = ExceptionDispatchInfo.Capture(processIdException);
}

this.task = this.CreateCombinedTask(processTask, ioTasks);
this.task = this.CreateCombinedTask(processTask.Task, timeout, cancellationToken, ioTasks);
}

private async Task<CommandResult> CreateCombinedTask(Task processTask, List<Task> ioTasks)
private async Task<CommandResult> CreateCombinedTask(
Task<int> processTask,
TimeSpan timeout,
CancellationToken cancellationToken,
List<Task> ioTasks)
{
int exitCode;
try
{
await processTask.ConfigureAwait(false);
exitCode = this.process.ExitCode;
// we only set up timeout and cancellation AFTER starting the process. This prevents a race
// condition where we immediately try to kill the process before having started it and then proceed to start it.
// While we could avoid starting at all in such cases, that would leave the command in a weird state (no PID, no streams, etc)
await this.HandleCancellationAndTimeout(processTask, cancellationToken, timeout).ConfigureAwait(false);

exitCode = await processTask.ConfigureAwait(false);
}
finally
{
Expand All @@ -85,6 +94,25 @@ private async Task<CommandResult> CreateCombinedTask(Task processTask, List<Task
return new CommandResult(exitCode, this);
}

private async Task HandleCancellationAndTimeout(Task<int> processTask, CancellationToken cancellationToken, TimeSpan timeout)
{
using (var cancellationOrTimeout = CancellationOrTimeout.TryCreate(cancellationToken, timeout))
{
if (cancellationOrTimeout != null)
{
// wait for either cancellation/timeout or the process to finish
var completed = await SystemTask.WhenAny(cancellationOrTimeout.Task, processTask).ConfigureAwait(false);
if (completed == cancellationOrTimeout.Task)
{
// if cancellation/timeout finishes first, kill the process
TryKillProcess(this.process);
// propagate cancellation or timeout exception
await completed.ConfigureAwait(false);
}
}
}
}

private readonly Process process;
public override System.Diagnostics.Process Process
{
Expand Down Expand Up @@ -167,9 +195,9 @@ public override void Kill()
TryKillProcess(this.process);
}

private static Task CreateProcessTask(Process process, bool throwOnError, TimeSpan timeout)
private static TaskCompletionSource<int> CreateProcessTask(Process process, bool throwOnError)
{
var taskCompletionSource = new TaskCompletionSource<bool>();
var taskCompletionSource = new TaskCompletionSource<int>();
process.Exited += (o, e) =>
{
Log.WriteLine("Received exited event from {0}", process.Id);
Expand All @@ -180,38 +208,11 @@ private static Task CreateProcessTask(Process process, bool throwOnError, TimeSp
}
else
{
taskCompletionSource.SetResult(true);
taskCompletionSource.SetResult(process.ExitCode);
}
};
return timeout != Timeout.InfiniteTimeSpan
? AddTimeout(taskCompletionSource.Task, process, timeout)
: taskCompletionSource.Task;
}

private static async Task AddTimeout(Task task, Process process, TimeSpan timeout)
{
using (var timeoutCleanupSource = new CancellationTokenSource())
{
// wait for either the given task or the timeout to complete
// http://stackoverflow.com/questions/4238345/asynchronously-wait-for-taskt-to-complete-with-timeout
var completed = await SystemTask.WhenAny(task, SystemTask.Delay(timeout, timeoutCleanupSource.Token)).ConfigureAwait(false);

// Task.WhenAny() swallows errors: wait for the completed task to propagate any errors that occurred
await completed.ConfigureAwait(false);

// if we timed out, kill the process
if (completed != task)
{
Log.WriteLine("Process timed out");
TryKillProcess(process);
throw new TimeoutException("Process killed after exceeding timeout of " + timeout);
}
else
{
// clean up the timeout
timeoutCleanupSource.Cancel();
}
}
return taskCompletionSource;
}

private static void TryKillProcess(Process process)
Expand Down
Loading

0 comments on commit e6219fc

Please sign in to comment.