Skip to content
16 changes: 16 additions & 0 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName
metrics.Bottleneck.DatasetUsed = dataset;
while (cursor.MoveNext())
{
CheckAlive();
labelGetter(ref label);
imageGetter(ref image);
if (image.Length <= 0)
Expand Down Expand Up @@ -888,6 +889,7 @@ private void CreateFeaturizedCacheFile(string cacheFilePath, int examples, int f

foreach (var row in featurizedImages)
{
CheckAlive();
writer.WriteLine(row.Item1 + "," + string.Join(",", row.Item2));
labels[0] = row.Item1;
for (int index = 0; index < sizeof(long); index++)
Expand Down Expand Up @@ -992,6 +994,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

for (int epoch = 0; epoch < epochs; epoch += 1)
{
CheckAlive();
// Train.
TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset,
metrics, labelTensorShape, featureTensorShape, batchSize,
Expand Down Expand Up @@ -1119,6 +1122,19 @@ private void TrainAndEvaluateClassificationLayerCore(int epoch, float learningRa
}
}

private void CheckAlive()
{
try
{
Host.CheckAlive();
}
catch(OperationCanceledException)
{
TryCleanupTemporaryWorkspace();
throw;
}
}

private void TryCleanupTemporaryWorkspace()
{
if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath))
Expand Down