diff --git a/src/Microsoft.ML.Core/SearchSpace/BoolearnChoiceAttribute.cs b/src/Microsoft.ML.Core/SearchSpace/BoolearnChoiceAttribute.cs
new file mode 100644
index 0000000000..a9a50de5f0
--- /dev/null
+++ b/src/Microsoft.ML.Core/SearchSpace/BoolearnChoiceAttribute.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.SearchSpace;
+
+///
+/// Boolean choice attribute
+///
+[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
+public sealed class BooleanChoiceAttribute : Attribute
+{
+ ///
+ /// Create a .
+ ///
+ public BooleanChoiceAttribute()
+ {
+ DefaultValue = true;
+ }
+
+ ///
+ /// Create a with default value.
+ ///
+ /// default value for this option.
+ public BooleanChoiceAttribute(bool defaultValue)
+ {
+ DefaultValue = defaultValue;
+ }
+
+ public bool DefaultValue { get; }
+}
diff --git a/src/Microsoft.ML.Core/SearchSpace/ChoiceAttribute.cs b/src/Microsoft.ML.Core/SearchSpace/ChoiceAttribute.cs
new file mode 100644
index 0000000000..24db500703
--- /dev/null
+++ b/src/Microsoft.ML.Core/SearchSpace/ChoiceAttribute.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Diagnostics.Contracts;
+using System.Linq;
+
+namespace Microsoft.ML.SearchSpace;
+
+///
+/// Choice attribute
+///
+[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
+public sealed class ChoiceAttribute : Attribute
+{
+ ///
+ /// Create a with .
+ ///
+ public ChoiceAttribute(params object[] candidates)
+ {
+ var candidatesType = candidates.Select(o => o.GetType()).Distinct();
+ Contract.Assert(candidatesType.Count() == 1, "multiple candidates type detected");
+ this.Candidates = candidates;
+ this.DefaultValue = null;
+ }
+
+ ///
+ /// Create a with and .
+ ///
+ public ChoiceAttribute(object[] candidates, object defaultValue)
+ {
+ var candidatesType = candidates.Select(o => o.GetType()).Distinct();
+ Contract.Assert(candidatesType.Count() == 1, "multiple candidates type detected");
+ Contract.Assert(candidatesType.First() == defaultValue.GetType(), "candidates type doesn't match with defaultValue type");
+
+ this.Candidates = candidates;
+ this.DefaultValue = defaultValue;
+ }
+
+ ///
+ /// Get the candidates of this option.
+ ///
+ public object[] Candidates { get; }
+
+ ///
+ /// Get the default value of this option.
+ ///
+ public object DefaultValue { get; }
+}
diff --git a/src/Microsoft.ML.Core/SearchSpace/NestOptionAttribute.cs b/src/Microsoft.ML.Core/SearchSpace/NestOptionAttribute.cs
new file mode 100644
index 0000000000..2a46530ed8
--- /dev/null
+++ b/src/Microsoft.ML.Core/SearchSpace/NestOptionAttribute.cs
@@ -0,0 +1,21 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.SearchSpace;
+
+///
+/// attribution class for nest option.
+///
+[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
+public sealed class NestOptionAttribute : Attribute
+{
+ ///
+ /// Create an .
+ ///
+ public NestOptionAttribute()
+ {
+ }
+}
diff --git a/src/Microsoft.ML.Core/SearchSpace/RangeAttribute.cs b/src/Microsoft.ML.Core/SearchSpace/RangeAttribute.cs
new file mode 100644
index 0000000000..f907650be7
--- /dev/null
+++ b/src/Microsoft.ML.Core/SearchSpace/RangeAttribute.cs
@@ -0,0 +1,88 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.SearchSpace;
+
+///
+/// Range attribute
+///
+[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
+public sealed class RangeAttribute : Attribute
+{
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(double min, double max, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = null;
+ this.LogBase = logBase;
+ }
+
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(double min, double max, double init, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = init;
+ this.LogBase = logBase;
+ }
+
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(int min, int max, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = null;
+ this.LogBase = logBase;
+ }
+
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(int min, int max, int init, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = init;
+ this.LogBase = logBase;
+ }
+
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(float min, float max, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = null;
+ this.LogBase = logBase;
+ }
+
+ ///
+ /// Create a
+ ///
+ public RangeAttribute(float min, float max, float init, bool logBase = false)
+ {
+ this.Min = min;
+ this.Max = max;
+ this.Init = init;
+ this.LogBase = logBase;
+ }
+
+ public object Min { get; }
+
+ public object Max { get; }
+
+ public object Init { get; }
+
+ public bool LogBase { get; }
+}
diff --git a/src/Microsoft.ML.SearchSpace/Assembly.cs b/src/Microsoft.ML.SearchSpace/Assembly.cs
index 39b50da899..880329026d 100644
--- a/src/Microsoft.ML.SearchSpace/Assembly.cs
+++ b/src/Microsoft.ML.SearchSpace/Assembly.cs
@@ -8,3 +8,7 @@
[assembly: InternalsVisibleTo("Microsoft.ML.SearchSpace.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
[assembly: InternalsVisibleTo("Microsoft.ML.AutoML, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
+[assembly: TypeForwardedTo(typeof(Microsoft.ML.SearchSpace.BooleanChoiceAttribute))]
+[assembly: TypeForwardedTo(typeof(Microsoft.ML.SearchSpace.ChoiceAttribute))]
+[assembly: TypeForwardedTo(typeof(Microsoft.ML.SearchSpace.NestOptionAttribute))]
+[assembly: TypeForwardedTo(typeof(Microsoft.ML.SearchSpace.RangeAttribute))]
diff --git a/src/Microsoft.ML.SearchSpace/Attribute/BooleanChoiceAttribute.cs b/src/Microsoft.ML.SearchSpace/Attribute/BooleanChoiceAttribute.cs
deleted file mode 100644
index be420ec22e..0000000000
--- a/src/Microsoft.ML.SearchSpace/Attribute/BooleanChoiceAttribute.cs
+++ /dev/null
@@ -1,37 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Globalization;
-using System.Linq;
-using Microsoft.ML.SearchSpace.Option;
-
-namespace Microsoft.ML.SearchSpace
-{
- ///
- /// Boolean choice attribute
- ///
- [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
- public sealed class BooleanChoiceAttribute : Attribute
- {
- ///
- /// Create a .
- ///
- public BooleanChoiceAttribute()
- {
- Option = new ChoiceOption(true, false);
- }
-
- ///
- /// Create a with default value.
- ///
- /// default value for this option.
- public BooleanChoiceAttribute(bool defaultValue)
- {
- Option = new ChoiceOption(new object[] { true, false }, defaultChoice: defaultValue);
- }
-
- internal ChoiceOption Option { get; }
- }
-}
diff --git a/src/Microsoft.ML.SearchSpace/Attribute/ChoiceAttribute.cs b/src/Microsoft.ML.SearchSpace/Attribute/ChoiceAttribute.cs
deleted file mode 100644
index ccb13cb16c..0000000000
--- a/src/Microsoft.ML.SearchSpace/Attribute/ChoiceAttribute.cs
+++ /dev/null
@@ -1,46 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Diagnostics.Contracts;
-using System.Globalization;
-using System.Linq;
-using Microsoft.ML.SearchSpace.Option;
-
-namespace Microsoft.ML.SearchSpace
-{
- ///
- /// attribution class for . The property or field it applys to will be treated as in .
- ///
- ///
- ///
- [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
- public sealed class ChoiceAttribute : Attribute
- {
- ///
- /// Create a with .
- ///
- public ChoiceAttribute(params object[] candidates)
- {
- var candidatesType = candidates.Select(o => o.GetType()).Distinct();
- Contract.Assert(candidatesType.Count() == 1, "multiple candidates type detected");
-
- Option = new ChoiceOption(candidates);
- }
-
- ///
- /// Create a with and .
- ///
- public ChoiceAttribute(object[] candidates, object defaultValue)
- {
- var candidatesType = candidates.Select(o => o.GetType()).Distinct();
- Contract.Assert(candidatesType.Count() == 1, "multiple candidates type detected");
- Contract.Assert(candidatesType.First() == defaultValue.GetType(), "candidates type doesn't match with defaultValue type");
-
- Option = new ChoiceOption(candidates, defaultValue);
- }
-
- internal ChoiceOption Option { get; }
- }
-}
diff --git a/src/Microsoft.ML.SearchSpace/Attribute/NestOptionAttribute.cs b/src/Microsoft.ML.SearchSpace/Attribute/NestOptionAttribute.cs
deleted file mode 100644
index c7df028b20..0000000000
--- a/src/Microsoft.ML.SearchSpace/Attribute/NestOptionAttribute.cs
+++ /dev/null
@@ -1,22 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-
-namespace Microsoft.ML.SearchSpace
-{
- ///
- /// attribution class for nest option.
- ///
- [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
- public sealed class NestOptionAttribute : Attribute
- {
- ///
- /// Create an .
- ///
- public NestOptionAttribute()
- {
- }
- }
-}
diff --git a/src/Microsoft.ML.SearchSpace/Attribute/RangeAttribute.cs b/src/Microsoft.ML.SearchSpace/Attribute/RangeAttribute.cs
deleted file mode 100644
index d0acbea013..0000000000
--- a/src/Microsoft.ML.SearchSpace/Attribute/RangeAttribute.cs
+++ /dev/null
@@ -1,66 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using Microsoft.ML.SearchSpace.Option;
-
-namespace Microsoft.ML.SearchSpace
-{
- ///
- /// attribution class for , and .
- ///
- [AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = false, AllowMultiple = false)]
- public sealed class RangeAttribute : Attribute
- {
- ///
- /// Create a for .
- ///
- public RangeAttribute(double min, double max, bool logBase = false)
- {
- Option = new UniformDoubleOption(min, max, logBase);
- }
-
- ///
- /// Create a for .
- ///
- public RangeAttribute(double min, double max, double init, bool logBase = false)
- {
- Option = new UniformDoubleOption(min, max, logBase, init);
- }
-
- ///
- /// Create a for .
- ///
- public RangeAttribute(int min, int max, bool logBase = false)
- {
- Option = new UniformIntOption(min, max, logBase);
- }
-
- ///
- /// Create a for .
- ///
- public RangeAttribute(int min, int max, int init, bool logBase = false)
- {
- Option = new UniformIntOption(min, max, logBase, init);
- }
-
- ///
- /// Create a for .
- ///
- public RangeAttribute(float min, float max, bool logBase = false)
- {
- Option = new UniformSingleOption(min, max, logBase);
- }
-
- ///
- /// Create a for .
- ///
- public RangeAttribute(float min, float max, float init, bool logBase = false)
- {
- Option = new UniformSingleOption(min, max, logBase, init);
- }
-
- internal OptionBase Option { get; }
- }
-}
diff --git a/src/Microsoft.ML.SearchSpace/Microsoft.ML.SearchSpace.csproj b/src/Microsoft.ML.SearchSpace/Microsoft.ML.SearchSpace.csproj
index 155feb6f84..0390aa99df 100644
--- a/src/Microsoft.ML.SearchSpace/Microsoft.ML.SearchSpace.csproj
+++ b/src/Microsoft.ML.SearchSpace/Microsoft.ML.SearchSpace.csproj
@@ -9,6 +9,7 @@
+
diff --git a/src/Microsoft.ML.SearchSpace/SearchSpace.cs b/src/Microsoft.ML.SearchSpace/SearchSpace.cs
index ea29cc06a3..10e680ac8e 100644
--- a/src/Microsoft.ML.SearchSpace/SearchSpace.cs
+++ b/src/Microsoft.ML.SearchSpace/SearchSpace.cs
@@ -214,9 +214,18 @@ private Dictionary GetOptionsFromField(Type typeInfo)
OptionBase option = attributes.First() switch
{
- ChoiceAttribute choice => choice.Option,
- RangeAttribute range => range.Option,
- BooleanChoiceAttribute booleanChoice => booleanChoice.Option,
+ ChoiceAttribute choice => choice.DefaultValue == null ? new ChoiceOption(choice.Candidates) : new ChoiceOption(choice.Candidates, defaultChoice: choice.DefaultValue),
+ RangeAttribute range => (range.Min, range.Max, range.Init, range.LogBase) switch
+ {
+ (double min, double max, double init, bool logBase) => new UniformDoubleOption(min, max, logBase, init),
+ (double min, double max, null, bool logBase) => new UniformDoubleOption(min, max, logBase),
+ (int min, int max, int init, bool logBase) => new UniformIntOption(min, max, logBase, init),
+ (int min, int max, null, bool logBase) => new UniformIntOption(min, max, logBase),
+ (float min, float max, float init, bool logBase) => new UniformSingleOption(min, max, logBase, init),
+ (float min, float max, null, bool logBase) => new UniformSingleOption(min, max, logBase),
+ _ => throw new NotImplementedException(),
+ },
+ BooleanChoiceAttribute booleanChoice => new ChoiceOption(new object[] { true, false }, defaultChoice: booleanChoice.DefaultValue),
NestOptionAttribute nest => GetSearchSpaceOptionFromType(field.FieldType),
_ => throw new NotImplementedException(),
};
@@ -252,9 +261,18 @@ private Dictionary GetOptionsFromProperty(Type typeInfo)
OptionBase option = attributes.First() switch
{
- ChoiceAttribute choice => choice.Option,
- RangeAttribute range => range.Option,
- BooleanChoiceAttribute booleanChoice => booleanChoice.Option,
+ ChoiceAttribute choice => choice.DefaultValue == null ? new ChoiceOption(choice.Candidates) : new ChoiceOption(choice.Candidates, defaultChoice: choice.DefaultValue),
+ RangeAttribute range => (range.Min, range.Max, range.Init, range.LogBase) switch
+ {
+ (double min, double max, double init, bool logBase) => new UniformDoubleOption(min, max, logBase, init),
+ (double min, double max, null, bool logBase) => new UniformDoubleOption(min, max, logBase),
+ (int min, int max, int init, bool logBase) => new UniformIntOption(min, max, logBase, init),
+ (int min, int max, null, bool logBase) => new UniformIntOption(min, max, logBase),
+ (float min, float max, float init, bool logBase) => new UniformSingleOption(min, max, logBase, init),
+ (float min, float max, null, bool logBase) => new UniformSingleOption(min, max, logBase),
+ _ => throw new NotImplementedException(),
+ },
+ BooleanChoiceAttribute booleanChoice => new ChoiceOption(new object[] { true, false }, defaultChoice: booleanChoice.DefaultValue),
NestOptionAttribute nest => GetSearchSpaceOptionFromType(property.PropertyType),
_ => throw new NotImplementedException(),
};
@@ -274,7 +292,7 @@ private void CheckOptionType(object attribute, string optionName, Type type)
return;
}
- if (attribute is RangeAttribute range && (range.Option is UniformDoubleOption || range.Option is UniformSingleOption))
+ if (attribute is RangeAttribute range && (range.Min is double || range.Min is float))
{
Contract.Assert(type != typeof(int) && type != typeof(short) && type != typeof(long), $"[Option:{optionName}] UniformDoubleOption or UniformSingleOption can't apply to property or field which type is int or short or long");
return;
diff --git a/src/Microsoft.ML.StandardTrainers/Microsoft.ML.StandardTrainers.csproj b/src/Microsoft.ML.StandardTrainers/Microsoft.ML.StandardTrainers.csproj
index d3a7a06c3c..f3b20954de 100644
--- a/src/Microsoft.ML.StandardTrainers/Microsoft.ML.StandardTrainers.csproj
+++ b/src/Microsoft.ML.StandardTrainers/Microsoft.ML.StandardTrainers.csproj
@@ -10,10 +10,6 @@
-
- all
- true
-
diff --git a/src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertEncoder.cs b/src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertEncoder.cs
index a486db2278..ce9b2a192a 100644
--- a/src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertEncoder.cs
+++ b/src/Microsoft.ML.TorchSharp/NasBert/Models/NasBertEncoder.cs
@@ -12,6 +12,7 @@
using Microsoft.ML.TorchSharp.Utils;
using TorchSharp;
using TorchSharp.Modules;
+using static Microsoft.ML.TorchSharp.NasBert.Modules.SearchSpace;
namespace Microsoft.ML.TorchSharp.NasBert.Models
{
@@ -255,13 +256,13 @@ private List CheckBlockHiddenSize(int blockPerLayer)
for (var i = 0; i < DistillBlocks; ++i)
{
var hiddenSizesPerBlock = Enumerable.Range(i * blockPerLayer, blockPerLayer)
- .Select(j => SearchSpace.ArchHiddenSize[DiscreteArches[j]]).ToArray();
- var nextHiddenSize = SearchSpace.CheckHiddenDimensionsAndReturnMax(hiddenSizesPerBlock);
+ .Select(j => ArchHiddenSize[DiscreteArches[j]]).ToArray();
+ var nextHiddenSize = CheckHiddenDimensionsAndReturnMax(hiddenSizesPerBlock);
if (nextHiddenSize == 0)
{
if (hiddenSizePerBlock.Count == 0)
{
- nextHiddenSize = SearchSpace.ArchHiddenSize[^1];
+ nextHiddenSize = ArchHiddenSize[^1];
}
else
{