Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 131 additions & 24 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.IO;
using System.Linq;
using System.Text;
using System.Xml.Schema;
using Microsoft.ML;
using Microsoft.ML.CommandLine;
using Microsoft.ML.Data;
Expand Down Expand Up @@ -509,18 +510,31 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
if (type.GetItemType() != expectedType)
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], expectedType.ToString(), type.ToString());
var originalShape = _parent.TFInputShapes[i];
var shape = originalShape.dims;
var originalShapeNdim = originalShape.ndim;
var originalShapeDims = originalShape.dims;

var colTypeDims = vecType.Dimensions.Select(dim => (int)dim).ToArray();
if (shape == null || (shape.Length == 0))
var inputDataShapeNdim = colTypeDims.Length;

if (originalShapeDims == null || (originalShapeDims.Length == 0))
{
_fullySpecifiedShapes[i] = new TensorShape(colTypeDims);
if (_parent._addBatchDimensionInput)
{
var l = new int[_fullySpecifiedShapes[i].ndim + 1];
l[0] = 1;
for (int ishape = 1; ishape < l.Length; ishape++)
l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1];
_fullySpecifiedShapes[i] = new TensorShape(l);
}
}
else
{
// If the column is one dimension we make sure that the total size of the TF shape matches.
// Compute the total size of the known dimensions of the shape.
int valCount = 1;
int numOfUnkDim = 0;
foreach (var s in shape)
foreach (var s in originalShapeDims)
{
if (s > 0)
valCount *= s;
Expand All @@ -532,29 +546,122 @@ public Mapper(TensorFlowTransformer parent, DataViewSchema inputSchema) :
if (typeValueCount % valCount != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");

// If the shape is multi-dimensional, we should be able to create the length of the vector by plugging
// in a single value for the unknown shapes. For example, if the shape is [?,?,3], then there should exist a value
// d such that d*d*3 is equal to the length of the input column.
var d = numOfUnkDim > 0 ? Math.Pow(typeValueCount / valCount, 1.0 / numOfUnkDim) : 0;
if (d - (int)d != 0)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of length {typeValueCount}.");
// If the AddBatchDimensionInput is set to true, one unknown dimension(i.e. batch size) is inferrable
var trueNumOfUnkDim = _parent._addBatchDimensionInput && originalShapeDims[0]==-1 ? (numOfUnkDim - 1) : numOfUnkDim;

// Fill in the unknown dimensions.
var originalShapeNdim = originalShape.ndim;
var originalShapeDims = originalShape.dims;
var l = new int[originalShapeNdim];
for (int ishape = 0; ishape < originalShapeNdim; ishape++)
l[ishape] = originalShapeDims[ishape] == -1 ? (int)d : originalShapeDims[ishape];
_fullySpecifiedShapes[i] = new TensorShape(l);
}
//all dimensions are known(except batch dimension which can be unknown). Eg:
// originalShape = [-1,2,2], AddBatchDimensionInput = true
// Inferred shape:[1,2,2]
// originalShape = [2,2]
// Inferred shape:[2,2]
if (trueNumOfUnkDim == 0)
{
int[] l = new int[originalShapeNdim];
int tensorShapeIndex = 0;
if (_parent._addBatchDimensionInput)
{
l[0] = 1;
tensorShapeIndex = 1;
}
for (; tensorShapeIndex < l.Length ; tensorShapeIndex++)
l[tensorShapeIndex] = originalShapeDims[tensorShapeIndex];
_fullySpecifiedShapes[i] = new TensorShape(l);
}
// One unknown dimension, which can be inferred from input. Eg:
// originalShape = [-1], input length=(5)
// Inferred shape:[5]
// originalShape = [1, -1, 2, 2], input length = 8.
// Inferred shape:[1,2,2,2]
// originalShape = [-1,-1, 2, 2], AddBatchDimensionInput = true and input length = 8.
// Inferred shape:[1,2,2,2]
else if (trueNumOfUnkDim == 1)
{
int[] l = new int[originalShapeNdim];
int tensorShapeIndex = 0;
//attempt to infer single missing dimension from passed vector input
int missingDim = typeValueCount / valCount;

if (_parent._addBatchDimensionInput)
{
l[0] = 1;
tensorShapeIndex = 1;
}

for (; tensorShapeIndex < originalShapeNdim; tensorShapeIndex++)
{
//Fill in tensor shape for known dims with expected tensor shape.
if (originalShapeDims[tensorShapeIndex] != -1)
{
l[tensorShapeIndex] = originalShapeDims[tensorShapeIndex];
}
else
{
l[tensorShapeIndex] = missingDim;
}
}
_fullySpecifiedShapes[i] = new TensorShape(l);

if (_parent._addBatchDimensionInput)
{
var l = new int[_fullySpecifiedShapes[i].ndim + 1];
l[0] = 1;
for (int ishape = 1; ishape < l.Length; ishape++)
l[ishape] = _fullySpecifiedShapes[i].dims[ishape - 1];
_fullySpecifiedShapes[i] = new TensorShape(l);
}
// For more than one unknown dimension, try to infer shape from input. Eg:
// originalShape = [-1,-1, 2, 2], AddBatchDimensionInput = false, inputShape = [1,2,2,2].
// Inferred shape:[1,2,2,2]
// originalShape = [-1, -1, -1, 2], AddBatchDimensionInput = true, inputShape = [2,2,2].
// Inferred shape:[1,2,2,2]
// originalShape = [1,-1,-1, 2], AddBatchDimensionInput = false, inputShape = [2,2,2].
// Inferred shape:[1,2,2,2]
// originalShape = [2,-1,-1, 2], AddBatchDimensionInput = true, inputShape = [2,2,2].
// Inferred shape:[2, 2, 2, 2]- use batch dim from the graph
else
{
//attempt to fill unknown dims from input shape
int[] l = new int[originalShapeNdim];
int inputDataIndex = 0;
int tensorShapeIndex = 0;

//If the input data passed has one dimension less than the expected input tensor shape
if (originalShapeNdim - inputDataShapeNdim == 1)
{

// If _addBatchDimensionInput option is set to false,
// and batch dimension is unknown, suggest setting it to true. eg:
// originalShape = [-1,-1, 2, 2], AddBatchDimensionInput = false, inputShape = [2,2,2].
if (!_parent._addBatchDimensionInput && originalShapeDims[0] == -1)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of shape ({String.Join(",", colTypeDims)}). Consider setting addBatchDimensionInput to true.");

// Eg:
// originalShape = [-1,-1,-1, 2], AddBatchDimensionInput = true, inputShape = [2,2,2].
// Inferred shape:[1,2,2,2]
else if (_parent._addBatchDimensionInput && originalShapeDims[0] == -1)
{
l[0] = 1;
tensorShapeIndex = 1;
}
}
// If the number of input data dimensions do not match with that of tensor dimensions, throw. Eg:
// originalShape = [1,-1,-1, 2], inputShape = [4,2].
// originalShape = [1,-1,-1, 2], inputShape = [8].
// originalShape = [-1, 2], inputShape = [2,2,2].
else if (originalShapeNdim != inputDataShapeNdim)
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of shape ({String.Join(",", colTypeDims)}).");

for (; tensorShapeIndex < originalShapeNdim; tensorShapeIndex++, inputDataIndex++)
{
//Fill in tensor shape for unknown dims with input data shape.
if (originalShapeDims[tensorShapeIndex] == -1)
{
l[tensorShapeIndex] = colTypeDims[inputDataIndex];
}
// If the tensor shape dim is known, assert that input data dim matches with
// expected tensor shape dim.
else if (originalShapeDims[tensorShapeIndex] == colTypeDims[inputDataIndex])
{
l[tensorShapeIndex] = originalShapeDims[tensorShapeIndex];
}
else
throw Contracts.Except($"Input shape mismatch: Input '{_parent.Inputs[i]}' has shape {originalShape.ToString()}, but input data is of shape ({String.Join(",", colTypeDims)}).");
}
_fullySpecifiedShapes[i] = new TensorShape(l);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ private class ShapeData

// Data will be passed as flat vector.
// Intended data shape [1, 2, 2, 3], model shape [1, None, None, 3]
[VectorType(12)]
[VectorType(1, 2, 2, 3)]
public float[] FourDim;

// Data will be passed as 4-D vector.
Expand Down