diff --git a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs index 135ad3b333..fc620beeb4 100644 --- a/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs +++ b/src/Microsoft.ML.EntryPoints/PermutationFeatureImportance.cs @@ -62,6 +62,8 @@ internal static IDataView GetMetrics( RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { + Contracts.Check(roleMappedData.Schema.Feature != null, "Feature column not found."); + Contracts.Check(roleMappedData.Schema.Label != null, "Label column not found."); IDataView result; if (predictor.PredictionKind == PredictionKind.BinaryClassification) result = GetBinaryMetrics(env, predictor, roleMappedData, input); @@ -252,6 +254,7 @@ private static IDataView GetRankingMetrics( RoleMappedData roleMappedData, PermutationFeatureImportanceArguments input) { + Contracts.Check(roleMappedData.Schema.Group != null, "Group ID column not found."); var roles = roleMappedData.Schema.GetColumnRoleNames(); var featureColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Feature.Value).First().Value; var labelColumnName = roles.Where(x => x.Key.Value == RoleMappedSchema.ColumnRole.Label.Value).First().Value;