diff --git a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
index b62d4c217f..8c12390426 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json
@@ -73,7 +73,8 @@
"ForecastBySsa",
"TextClassifcation",
"SentenceSimilarity",
- "ObjectDetection"
+ "ObjectDetection",
+ "QuestionAnswering"
]
},
"nugetDependencies": {
diff --git a/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json
new file mode 100644
index 0000000000..83d7b0b538
--- /dev/null
+++ b/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json
@@ -0,0 +1,56 @@
+{
+ "$schema": "./search-space-schema.json#",
+ "name": "question_answering_option",
+ "search_space": [
+ {
+ "name": "ContextColumnName",
+ "type": "string",
+ "default": "Context"
+ },
+ {
+ "name": "QuestionColumnName",
+ "type": "string",
+ "default": "Question"
+ },
+ {
+ "name": "TrainingAnswerColumnName",
+ "type": "string",
+ "default": "TrainingAnswer"
+ },
+ {
+ "name": "AnswerIndexStartColumnName",
+ "type": "string",
+ "default": "AnswerStart"
+ },
+ {
+ "name": "ScoreColumnName",
+ "type": "string",
+ "default": "Score"
+ },
+ {
+ "name": "predictedAnswerColumnName",
+ "type": "string",
+ "default": "Answer"
+ },
+ {
+ "name": "BatchSize",
+ "type": "integer",
+ "default": 4
+ },
+ {
+ "name": "MaxEpochs",
+ "type": "integer",
+ "default": 10
+ },
+ {
+ "name": "TopKAnswers",
+ "type": "integer",
+ "default": 3
+ },
+ {
+ "name": "Architecture",
+ "type": "bertArchitecture",
+ "default": "BertArchitecture.Roberta"
+ }
+ ]
+}
diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
index e9e680e86e..5d28914734 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json
@@ -146,7 +146,8 @@
"dnn_featurizer_image_option",
"text_classification_option",
"sentence_similarity_option",
- "object_detection_option"
+ "object_detection_option",
+ "question_answering_option"
]
},
"option_name": {
@@ -210,7 +211,13 @@
"Steps",
"MaxEpoch",
"InitLearningRate",
- "WeightDecay"
+ "WeightDecay",
+ "ContextColumnName",
+ "QuestionColumnName",
+ "TrainingAnswerColumnName",
+ "AnswerIndexStartColumnName",
+ "predictedAnswerColumnName",
+ "TopKAnswers"
]
},
"option_type": {
diff --git a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
index 9eab12f286..0ce5a45e37 100644
--- a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
+++ b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json
@@ -532,6 +532,13 @@
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
"searchOption": "object_detection_option"
},
+ {
+ "functionName": "QuestionAnswering",
+ "estimatorTypes": [ "MultiClassification" ],
+ "nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
+ "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
+ "searchOption": "question_answering_option"
+ },
{
"functionName": "ForecastBySsa",
"estimatorTypes": [ "Forecasting" ],
diff --git a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
index 4f5c192cd6..6ecc501c11 100644
--- a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
+++ b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj
@@ -69,7 +69,7 @@
-
+
PreserveNewest
diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs
new file mode 100644
index 0000000000..64f99961ec
--- /dev/null
+++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs
@@ -0,0 +1,28 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+using Microsoft.ML.TorchSharp;
+using Microsoft.ML.TorchSharp.NasBert;
+using Microsoft.ML.TorchSharp.Roberta;
+
+namespace Microsoft.ML.AutoML.CodeGen
+{
+    /// <summary>Sweepable estimator whose trainer is created through the <c>QuestionAnswer</c> multiclass trainer extension.</summary>
+    internal partial class QuestionAnsweringMulti
+    {
+        /// <summary>Creates the question-answering trainer, forwarding column names, batch size, epoch count and top-K from <paramref name="param"/>.</summary>
+        /// <remarks>The architecture argument is always <see cref="BertArchitecture.Roberta"/>; no architecture value is read from <paramref name="param"/>.</remarks>
+        public override IEstimator BuildFromOption(MLContext context, QuestionAnsweringOption param) =>
+            context.MulticlassClassification.Trainers.QuestionAnswer(
+                contextColumnName: param.ContextColumnName,
+                questionColumnName: param.QuestionColumnName,
+                trainingAnswerColumnName: param.TrainingAnswerColumnName,
+                answerIndexColumnName: param.AnswerIndexStartColumnName,
+                predictedAnswerColumnName: param.PredictedAnswerColumnName,
+                scoreColumnName: param.ScoreColumnName,
+                batchSize: param.BatchSize,
+                maxEpochs: param.MaxEpochs,
+                topK: param.TopKAnswers,
+                architecture: BertArchitecture.Roberta);
+    }
+}