diff --git a/build/Dependencies.props b/build/Dependencies.props
index 903b984a40..8bb8fceb5e 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -55,7 +55,7 @@
3.0.1
0.0.6-test
0.0.6-test
- 0.0.11-test
+ 0.0.12-test
0.0.6-test
4.6.1
1.2.7
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index 12d1ecf518..6c46f59b2d 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -529,6 +529,18 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
+ // This cover the 2-variable senario e.g. [?, ?, ?, C] where we can assume typeDims provides the information of [W, H, C]
+ // The shape will become [?, W, H, C]
+ var originalShapeDims = originalShape.dims;
+ var originalShapeNdim = originalShape.ndim;
+ if (numOfUnkDim == 3 && colTypeDims.Length == 3 && originalShapeNdim == numOfUnkDim + 1 && originalShapeDims[1] == -1)
+ {
+ originalShapeDims[1] = colTypeDims[0];
+ originalShapeDims[2] = colTypeDims[1];
+ valCount *= originalShapeDims[1] * originalShapeDims[2];
+ numOfUnkDim -= 2;
+ }
+
// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
// d such that d*d*3 is equal to the length of the input column.
@@ -537,8 +549,6 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
// Fill in the unknown dimensions.
- var originalShapeNdim = originalShape.ndim;
- var originalShapeDims = originalShape.dims;
var l = new int[originalShapeNdim];
for (int ishape = 0; ishape < originalShapeNdim; ishape++)
l[ishape] = originalShapeDims[ishape] == -1 ? (int)d : originalShapeDims[ishape];
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 5130391d23..d908ab247c 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -1900,5 +1900,40 @@ private static string GetTemporaryDirectory()
Directory.CreateDirectory(tempDirectory);
return tempDirectory;
}
+
+ [TensorFlowFact]
+ public void TensorflowPlaceholderShapeInferenceTest()
+ {
+ //frozen_model_variadic_input_shape.pb is modified by frozen_model.pb
+ //the shape of placeholder is changed from [?, w, h, c] to [?, ?, ?, c]
+ string modelLocation = "cifar_model/frozen_model_variadic_input_shape.pb";
+
+ int imageHeight = 32;
+ int imageWidth = 32;
+ string dataFile = GetDataPath("images/images.tsv");
+ string imageFolder = Path.GetDirectoryName(dataFile);
+
+ IDataView data = _mlContext.Data.LoadFromTextFile(dataFile, new[] {
+ new TextLoader.Column("imagePath", DataKind.String, 0),
+ new TextLoader.Column("name", DataKind.String, 1)
+ });
+
+ Tensorflow.TensorShape[] tfInputShape;
+
+ using (var tfModel = _mlContext.Model.LoadTensorFlowModel(modelLocation))
+ {
+ var pipeline = _mlContext.Transforms.LoadImages("Input", imageFolder, "imagePath")
+ .Append(_mlContext.Transforms.ResizeImages("Input", imageHeight, imageWidth))
+ .Append(_mlContext.Transforms.ExtractPixels("Input", interleavePixelColors: true))
+ .Append(tfModel.ScoreTensorFlowModel("Output", "Input"));
+
+ var transformer = pipeline.Fit(data);
+
+ tfInputShape = transformer.LastTransformer.TFInputShapes;
+ }
+
+ Assert.Equal(imageHeight, tfInputShape.ElementAt(0)[1].dims[0]);
+ Assert.Equal(imageWidth, tfInputShape.ElementAt(0)[2].dims[0]);
+ }
}
}