diff --git a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs index eba8029fdb..4f7b8fab08 100644 --- a/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs +++ b/src/Microsoft.ML.Vision/ImageClassificationTrainer.cs @@ -835,6 +835,7 @@ private void CacheFeaturizedImagesToDisk(IDataView input, string labelColumnName metrics.Bottleneck.DatasetUsed = dataset; while (cursor.MoveNext()) { + CheckAlive(); labelGetter(ref label); imageGetter(ref image); if (image.Length <= 0) @@ -888,6 +889,7 @@ private void CreateFeaturizedCacheFile(string cacheFilePath, int examples, int f foreach (var row in featurizedImages) { + CheckAlive(); writer.WriteLine(row.Item1 + "," + string.Join(",", row.Item2)); labels[0] = row.Item1; for (int index = 0; index < sizeof(long); index++) @@ -992,6 +994,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath, for (int epoch = 0; epoch < epochs; epoch += 1) { + CheckAlive(); // Train. TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset, metrics, labelTensorShape, featureTensorShape, batchSize, @@ -1119,6 +1122,19 @@ private void TrainAndEvaluateClassificationLayerCore(int epoch, float learningRa } } + private void CheckAlive() + { + try + { + Host.CheckAlive(); + } + catch(OperationCanceledException) + { + TryCleanupTemporaryWorkspace(); + throw; + } + } + private void TryCleanupTemporaryWorkspace() { if (_cleanupWorkspace && Directory.Exists(_options.WorkspacePath))