Skip to content
73 changes: 73 additions & 0 deletions src/Microsoft.ML.Dnn/ImageClassificationTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option

if (_session.graph.OperationByName(_labelTensor.name.Split(':')[0]) == null)
throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model");
if (options.EarlyStoppingCriteria != null && options.ValidationSet == null && options.TestOnTrainSet == false)
throw Host.ExceptParam(nameof(options.EarlyStoppingCriteria), $"Early stopping enabled but unable to find a validation" +
$" set and/or train set testing disabled. Please disable early stopping or either provide a validation set or enable train set training.");
}

private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth)
Expand Down Expand Up @@ -381,6 +384,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
float crossentropy = 0;
for (int epoch = 0; epoch < epochs; epoch += 1)
{
batchIndex = 0;
metrics.Train.Accuracy = 0;
metrics.Train.CrossEntropy = 0;
metrics.Train.BatchProcessedCount = 0;
Expand Down Expand Up @@ -432,6 +436,42 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
}
}

//Process last incomplete batch
if (batchIndex > 0)
{
featureTensorShape[0] = batchIndex;
featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex;
labelTensorShape[0] = batchIndex;
labelBatchSizeInBytes = sizeof(long) * batchIndex;
runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
.Run();

metrics.Train.BatchProcessedCount += 1;

if (options.TestOnTrainSet && statisticsCallback != null)
{
var outputTensors = testEvalRunner
.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
.Run();

outputTensors[0].ToScalar<float>(ref accuracy);
outputTensors[1].ToScalar<float>(ref crossentropy);
metrics.Train.Accuracy += accuracy;
metrics.Train.CrossEntropy += crossentropy;

outputTensors[0].Dispose();
outputTensors[1].Dispose();
}

batchIndex = 0;
featureTensorShape[0] = batchSize;
featureBatchSizeInBytes = sizeof(float) * featureBatch.Length;
labelTensorShape[0] = batchSize;
labelBatchSizeInBytes = sizeof(long) * batchSize;
}

if (options.TestOnTrainSet && statisticsCallback != null)
{
metrics.Train.Epoch = epoch;
Expand All @@ -443,7 +483,15 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
}

if (validationSet == null)
{
//Early stopping check
if (options.EarlyStoppingCriteria != null)
{
if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train))
break;
}
continue;
}

batchIndex = 0;
metrics.Train.BatchProcessedCount = 0;
Expand Down Expand Up @@ -481,6 +529,31 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
}
}

//Process last incomplete batch
if(batchIndex > 0)
{
featureTensorShape[0] = batchIndex;
featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex;
labelTensorShape[0] = batchIndex;
labelBatchSizeInBytes = sizeof(long) * batchIndex;
var outputTensors = validationEvalRunner
.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
.Run();

outputTensors[0].ToScalar<float>(ref accuracy);
metrics.Train.Accuracy += accuracy;
metrics.Train.BatchProcessedCount += 1;
batchIndex = 0;

featureTensorShape[0] = batchSize;
featureBatchSizeInBytes = sizeof(float) * featureBatch.Length;
labelTensorShape[0] = batchSize;
labelBatchSizeInBytes = sizeof(long) * batchSize;

outputTensors[0].Dispose();
}

if (statisticsCallback != null)
{
metrics.Train.Epoch = epoch;
Expand Down