Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/Microsoft.ML.PipelineInference/RecipeInference.cs
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,7 @@ public static SuggestedRecipe.SuggestedLearner[] AllowedLearners(IHostEnvironmen
var type = typeof(CommonInputs.ITrainerInput);
var trainerTypes = typeof(Experiment).Assembly.GetTypes()
.Where(p => type.IsAssignableFrom(p) &&
p.Name.Equals(MacroUtils.GetTrainerName(trainerKind)));
MacroUtils.IsTrainerOfKind(p, trainerKind));

foreach (var tt in trainerTypes)
{
Expand All @@ -516,7 +516,7 @@ public static SuggestedRecipe.SuggestedLearner[] AllowedLearners(IHostEnvironmen
var sl = new SuggestedRecipe.SuggestedLearner
{
PipelineNode = new TrainerPipelineNode(epInputObj, sweepParams),
LearnerName = tt.Namespace
LearnerName = tt.Name
};

if (sl.PipelineNode != null && availableLearnersList.FirstOrDefault(l=> l.Name.Equals(sl.PipelineNode.GetEpName())) != null)
Expand Down
32 changes: 24 additions & 8 deletions src/Microsoft.ML/Runtime/EntryPoints/MacroUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureBinaryClassifierTrainer,
new TaskInformationBundle {
TrainerFunctionName = "TrainBinary",
TrainerFunctionName = "BinaryClassifier",
TrainerSignatureType = typeof(SignatureBinaryClassifierTrainer),
EvaluatorInput = settings => new Models.BinaryClassificationEvaluator
{
Expand All @@ -71,7 +71,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureMultiClassClassifierTrainer,
new TaskInformationBundle{
TrainerFunctionName = "TrainMultiClass",
TrainerFunctionName = "Classifier",
TrainerSignatureType = typeof(SignatureMultiClassClassifierTrainer),
EvaluatorInput = settings => new Models.ClassificationEvaluator
{
Expand All @@ -87,7 +87,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureRankerTrainer,
new TaskInformationBundle {
TrainerFunctionName = "TrainRanking",
TrainerFunctionName = "Ranker",
TrainerSignatureType = typeof(SignatureRankerTrainer),
EvaluatorInput = settings => new Models.RankerEvaluator
{
Expand All @@ -103,7 +103,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureRegressorTrainer,
new TaskInformationBundle{
TrainerFunctionName = "TrainRegression",
TrainerFunctionName = "Regressor",
TrainerSignatureType = typeof(SignatureRegressorTrainer),
EvaluatorInput = settings => new Models.RegressionEvaluator
{
Expand All @@ -119,7 +119,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureMultiOutputRegressorTrainer,
new TaskInformationBundle {
TrainerFunctionName = "TrainMultiRegression",
TrainerFunctionName = "MultiOutputRegressor",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MultiOutputRegressor [](start = 47, length = 20)

You said something about a name suffix... is it a concern that the suffixes are not unique? So Regressor above would be a suffix of any MultiOutputRegressor. We previously worked around this problem I guess by having every suffix start with Train, but that solution has gone out the window I suppose.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar notes for Classifier vs. BinaryClassifier.


In reply to: 187424111 [](ancestors = 187424111)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK now I know you've special cased this. :) Hmmm I'd still rather have a different solution.


In reply to: 187424194 [](ancestors = 187424194,187424111)

TrainerSignatureType = typeof(SignatureMultiOutputRegressorTrainer),
EvaluatorInput = settings => new Models.MultiOutputRegressionEvaluator
{
Expand All @@ -135,7 +135,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureAnomalyDetectorTrainer,
new TaskInformationBundle {
TrainerFunctionName = "TrainAnomalyDetection",
TrainerFunctionName = "AnomalyDetector",
TrainerSignatureType = typeof(SignatureAnomalyDetectorTrainer),
EvaluatorInput = settings => new Models.AnomalyDetectionEvaluator
{
Expand All @@ -151,7 +151,7 @@ private static Dictionary<TrainerKinds, TaskInformationBundle>
{
TrainerKinds.SignatureClusteringTrainer,
new TaskInformationBundle {
TrainerFunctionName = "TrainClustering",
TrainerFunctionName = "Clusterer",
TrainerSignatureType = typeof(SignatureClusteringTrainer),
EvaluatorInput = settings => new Models.ClusterEvaluator
{
Expand Down Expand Up @@ -186,13 +186,29 @@ public static TrainerKinds SignatureTypeToTrainerKind(Type sigType)
public static TrainerKinds[] SignatureTypesToTrainerKinds(IEnumerable<Type> sigTypes) =>
sigTypes.Select(SignatureTypeToTrainerKind).ToArray();

public static string GetTrainerName(TrainerKinds kind) => TrainerKindDict[kind].TrainerFunctionName;
private static string GetTrainerName(TrainerKinds kind) => TrainerKindDict[kind].TrainerFunctionName;

public static T TrainerKindApiValue<T>(TrainerKinds trainerKind)
{
if (Enum.GetName(typeof(TrainerKinds), trainerKind) is string name)
return (T)Enum.Parse(typeof(T), name);
throw new Exception($"Could not interpret enum value: {trainerKind}");
}

public static bool IsTrainerOfKind(Type type, TrainerKinds trainerKind)
{
if (type == typeof(Trainers.BinaryLogisticRegressor))
return trainerKind == TrainerKinds.SignatureBinaryClassifierTrainer;
if (type == typeof(Trainers.LogisticRegressor))
return trainerKind == TrainerKinds.SignatureMultiClassClassifierTrainer;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

God, this is so ugly.
Is there a way to cast EntryPoint type to real type? Maybe you can use EntryPointCatalog to find mapping, and after you get real type you can maybe get prediction type, or get kind through reflection?
This is unmaintainable.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A better solution might be harder to imagine though. We might let this slide for now since this change will fix things that are broken now, but we ought definitely to file a new issue so we can imagine something a bit less fragile than this system.


In reply to: 187421786 [](ancestors = 187421786)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this be addressed when we change the mechanism to use the EntryPoint attribute?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The entry-point attribute may indeed be the best place to put a category like this -- something akin to the old-style PredictionKind enum, but not that please. :D :D :D


In reply to: 187457279 [](ancestors = 187457279)


if (trainerKind != TrainerKinds.SignatureMultiClassClassifierTrainer && trainerKind != TrainerKinds.SignatureMultiOutputRegressorTrainer)
return type.Name.EndsWith(GetTrainerName(trainerKind));

if (trainerKind == TrainerKinds.SignatureMultiClassClassifierTrainer)
return type.Name.EndsWith(GetTrainerName(trainerKind)) && !type.Name.EndsWith(GetTrainerName(TrainerKinds.SignatureBinaryClassifierTrainer));

return type.Name.EndsWith(GetTrainerName(trainerKind)) && !type.Name.EndsWith(GetTrainerName(TrainerKinds.SignatureRegressorTrainer));
}
}
}