Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,14 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
throw Host.Except($"Variable length input columns not supported");

if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
{
// If the ONNX model input node expects a type that mismatches with the type of the input IDataView column that is provided
// then throw an exception.
// This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
if (!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32)))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
}

// If the column is one dimension we make sure that the total size of the Onnx shape matches.
// Compute the total size of the known dimensions of the shape.
Expand Down
68 changes: 68 additions & 0 deletions test/Microsoft.ML.Tests/OnnxConversionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1450,6 +1450,74 @@ public void CopyColumnsOnnxTest()
Done();
}

/// <summary>
/// Applies an ONNX model that expects a UInt32 input column to an input column of
/// <c>KeyDataViewType</c>, and checks that the output matches the output of an equivalent
/// ONNX model that performs the value-to-key mapping itself.
/// This scenario is needed by NimbusML; see https://github.com/microsoft/NimbusML/issues/426
/// </summary>
[Fact]
public void UseKeyDataViewTypeAsUInt32InOnnxInput()
{
    // Step 1: Load the Iris dataset and apply a value-to-key mapping to its Label column.
    // Save the resulting data view in .idv format, dropping all hidden columns.
    var mlContext = new MLContext();
    var loader = mlContext.Data.CreateTextLoader(
        columns: new[]
        {
            new TextLoader.Column("Label", DataKind.String, 0),
            new TextLoader.Column("SepalLength", DataKind.Single, 1),
            new TextLoader.Column("SepalWidth", DataKind.Single, 2),
            new TextLoader.Column("PetalLength", DataKind.Single, 3),
            new TextLoader.Column("PetalWidth", DataKind.Single, 4)
        },
        hasHeader: false
    );

    string dataPath = GetDataPath("iris.txt");
    var originalData = loader.Load(dataPath);
    var pipeline1 = mlContext.Transforms.Conversion.MapValueToKey("Label");
    var mappedData = pipeline1.Fit(originalData).Transform(originalData);

    // Write into the test output directory (like the ONNX models below) rather than a
    // machine-specific absolute path, so the test can run on any machine.
    string mappedDataPath = GetOutputPath("kdvt-as-uint32-mapped-data.idv");
    using (FileStream stream = new FileStream(mappedDataPath, FileMode.Create))
    {
        mlContext.Data.SaveAsBinary(mappedData, stream, keepHidden: false);
    }

    // Step 2: Load back the saved .idv.
    // This IDataView will have a Label column of type KeyDataViewType.
    // It is necessary to do this because, if mappedData were used directly in the next
    // steps, saving the ONNX model would also save the ValueToKeyTransformer part,
    // and that would not reproduce the scenario.
    IDataView reloadedData = mlContext.Data.LoadFromBinary(mappedDataPath);

    // Step 3: Create an ONNX model which simply applies Identity to the Label column.
    var pipeline2 = mlContext.Transforms.CopyColumns("Label", "Label");
    var model = pipeline2.Fit(reloadedData);

    var onnxModelPath = GetOutputPath("onnxmodel1-kdvt-as-uint32.onnx");
    using (FileStream stream = new FileStream(onnxModelPath, FileMode.Create))
    {
        mlContext.Model.ConvertToOnnx(model, reloadedData, stream);
    }

    // Step 4: Get the input and output names of the model.
    var onnxProtoBufModel = mlContext.Model.ConvertToOnnxProtobuf(model, reloadedData);
    string[] inputNames = onnxProtoBufModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray();
    string[] outputNames = onnxProtoBufModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray();

    // Step 5: Apply the ONNX model to the KeyDataViewType column.
    var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath);
    var onnxResult = onnxEstimator.Fit(reloadedData).Transform(reloadedData);

    // Step 6: Compare the results to an ONNX model created using the mappedData IDataView.
    // Notice that this ONNX model would actually include the steps to do the ValueToKeyTransformer mapping.
    // Because of this, it can only be applied to the originalData data view, even though mappedData
    // was used to create the model. Applying it to mappedData or reloadedData would throw an exception,
    // since the ONNX model will expect a Label input of type string (which only originalData provides).
    string onnxModelPath2 = GetOutputPath("onnxmodel2-kdvt-as-uint32.onnx");
    using (FileStream stream = new FileStream(onnxModelPath2, FileMode.Create))
    {
        mlContext.Model.ConvertToOnnx(model, mappedData, stream);
    }

    var onnxEstimator2 = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath2);
    var onnxResult2 = onnxEstimator2.Fit(originalData).Transform(originalData);

    foreach (var name in outputNames)
    {
        CompareResults(name, name, onnxResult, onnxResult2);
    }
}

[Fact]
public void FeatureSelectionOnnxTest()
{
Expand Down