diff --git a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs index bad3d28d6a..2f5c220fe9 100644 --- a/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs +++ b/src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs @@ -331,19 +331,21 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string var nameX = featureColumn; // Compute X^2 from X - var nameX2 = ctx.AddIntermediateVariable(null, "X2", true); + var nameX2 = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1), "X2"); var reduceNodeX2 = ctx.CreateNode("ReduceSumSquare", nameX, nameX2, ctx.GetNodeName("ReduceSumSquare"), ""); + reduceNodeX2.AddAttribute("axes", new long[] { 1 }); // Compute -2XC^T. Note that Gemm always takes three inputs. Since we only have two here, // a dummy one, named zero, is created. + var dataViewType = new VectorDataViewType(NumberDataViewType.Single, _centroids.Length); var zeroName = ctx.AddInitializer(new float[] { 0f }, null, "zero"); - var nameXC2 = ctx.AddIntermediateVariable(null, "XC2", true); + var nameXC2 = ctx.AddIntermediateVariable(dataViewType, "XC2"); var gemmNodeXC2 = ctx.CreateNode("Gemm", new[] { nameX, nameC, zeroName }, new[] { nameXC2 }, ctx.GetNodeName("Gemm"), ""); gemmNodeXC2.AddAttribute("alpha", -2f); gemmNodeXC2.AddAttribute("transB", 1); // Compute Z = X^2 - 2XC^T - var nameZ = ctx.AddIntermediateVariable(null, "Z", true); + var nameZ = ctx.AddIntermediateVariable(dataViewType, "Z"); var addNodeZ = ctx.CreateNode("Add", new[] { nameX2, nameXC2 }, new[] { nameZ }, ctx.GetNodeName("Add"), ""); // Compute Y = Z + C^2 diff --git a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt index 50624d3304..a1324ac03d 100644 --- a/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt +++ b/test/BaselineOutput/Common/Onnx/Cluster/BreastCancer/Kmeans.txt @@ -56,7 +56,16 @@ "X2" ], "name": "ReduceSumSquare", - "opType": "ReduceSumSquare" + "opType": "ReduceSumSquare", + "attribute": [ + { + "name": "axes", + "ints": [ + "1" + ], + "type": "INTS" + } + ] }, { "input": [ @@ -377,6 +386,60 @@ } } }, + { + "name": "X2", + "type": { + "tensorType": { + "elemType": 1, + "shape": { + "dim": [ + { + "dimValue": "-1" + }, + { + "dimValue": "1" + } + ] + } + } + } + }, + { + "name": "XC2", + "type": { + "tensorType": { + "elemType": 1, + "shape": { + "dim": [ + { + "dimValue": "-1" + }, + { + "dimValue": "4" + } + ] + } + } + } + }, + { + "name": "Z", + "type": { + "tensorType": { + "elemType": 1, + "shape": { + "dim": [ + { + "dimValue": "-1" + }, + { + "dimValue": "4" + } + ] + } + } + } + }, { "name": "Features.output", "type": {