diff --git a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json index b62d4c217f..8c12390426 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json +++ b/src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json @@ -73,7 +73,8 @@ "ForecastBySsa", "TextClassifcation", "SentenceSimilarity", - "ObjectDetection" + "ObjectDetection", + "QuestionAnswering" ] }, "nugetDependencies": { diff --git a/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json new file mode 100644 index 0000000000..83d7b0b538 --- /dev/null +++ b/src/Microsoft.ML.AutoML/CodeGen/question_answering_search_space.json @@ -0,0 +1,56 @@ +{ + "$schema": "./search-space-schema.json#", + "name": "question_answering_option", + "search_space": [ + { + "name": "ContextColumnName", + "type": "string", + "default": "Context" + }, + { + "name": "QuestionColumnName", + "type": "string", + "default": "Question" + }, + { + "name": "TrainingAnswerColumnName", + "type": "string", + "default": "TrainingAnswer" + }, + { + "name": "AnswerIndexStartColumnName", + "type": "string", + "default": "AnswerStart" + }, + { + "name": "ScoreColumnName", + "type": "string", + "default": "Score" + }, + { + "name": "predictedAnswerColumnName", + "type": "string", + "default": "Answer" + }, + { + "name": "BatchSize", + "type": "integer", + "default": 4 + }, + { + "name": "MaxEpochs", + "type": "integer", + "default": 10 + }, + { + "name": "TopKAnswers", + "type": "integer", + "default": 3 + }, + { + "name": "Architecture", + "type": "bertArchitecture", + "default": "BertArchitecture.Roberta" + } + ] +} diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json index e9e680e86e..5d28914734 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json +++ 
b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json @@ -146,7 +146,8 @@ "dnn_featurizer_image_option", "text_classification_option", "sentence_similarity_option", - "object_detection_option" + "object_detection_option", + "question_answering_option" ] }, "option_name": { @@ -210,7 +211,13 @@ "Steps", "MaxEpoch", "InitLearningRate", - "WeightDecay" + "WeightDecay", + "ContextColumnName", + "QuestionColumnName", + "TrainingAnswerColumnName", + "AnswerIndexStartColumnName", + "predictedAnswerColumnName", + "TopKAnswers" ] }, "option_type": { diff --git a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json index 9eab12f286..0ce5a45e37 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json +++ b/src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json @@ -532,6 +532,13 @@ "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ], "searchOption": "object_detection_option" }, + { + "functionName": "QuestionAnswering", + "estimatorTypes": [ "MultiClassification" ], + "nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ], + "usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ], + "searchOption": "question_answering_option" + }, { "functionName": "ForecastBySsa", "estimatorTypes": [ "Forecasting" ], diff --git a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj index 4f5c192cd6..6ecc501c11 100644 --- a/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj +++ b/src/Microsoft.ML.AutoML/Microsoft.ML.AutoML.csproj @@ -69,7 +69,7 @@ - + PreserveNewest diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs new file mode 100644 index 0000000000..64f99961ec --- /dev/null +++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/QuestionAnswering.cs @@ -0,0 +1,28 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using Microsoft.ML.TorchSharp;
using Microsoft.ML.TorchSharp.NasBert;
using Microsoft.ML.TorchSharp.Roberta;

namespace Microsoft.ML.AutoML.CodeGen
{
    /// <summary>
    /// Sweepable estimator that creates the question-answering trainer from a
    /// <see cref="QuestionAnsweringOption"/>.
    /// </summary>
    internal partial class QuestionAnsweringMulti
    {
        /// <summary>
        /// Creates the question-answering trainer by forwarding every value of
        /// <paramref name="param"/> to <c>context.MulticlassClassification.Trainers.QuestionAnswer</c>.
        /// </summary>
        /// <param name="context">The <see cref="MLContext"/> whose multiclass trainer catalog is used.</param>
        /// <param name="param">Option values (column names, batch size, epochs, top-k, architecture) for the trainer.</param>
        /// <returns>The estimator returned by the <c>QuestionAnswer</c> trainer extension.</returns>
        public override IEstimator BuildFromOption(MLContext context, QuestionAnsweringOption param)
        {
            return context.MulticlassClassification.Trainers.QuestionAnswer(
                contextColumnName: param.ContextColumnName,
                questionColumnName: param.QuestionColumnName,
                trainingAnswerColumnName: param.TrainingAnswerColumnName,
                answerIndexColumnName: param.AnswerIndexStartColumnName,
                predictedAnswerColumnName: param.PredictedAnswerColumnName,
                scoreColumnName: param.ScoreColumnName,
                batchSize: param.BatchSize,
                maxEpochs: param.MaxEpochs,
                topK: param.TopKAnswers,
                // Honor the Architecture option instead of hard-coding Roberta, so a
                // configured value is not silently ignored. The search space default is
                // BertArchitecture.Roberta, so default behavior is unchanged.
                // NOTE(review): assumes the generated QuestionAnsweringOption exposes an
                // Architecture property for the "Architecture" search-space entry - confirm.
                architecture: param.Architecture);
        }
    }
}