Skip to content

Commit

Permalink
Deterministic actions python training (#5619)
Browse files Browse the repository at this point in the history
* Progress on propagating the setting to the action model.

* Added the _sample_action logic and tests.

* Add information to the changelog.

* Prioritize the CLI over the configuration file.

* Update documentation for config file.

* CR refactor.

* Update docs/Training-Configuration-File.md

Co-authored-by: Miguel Alonso Jr. <[email protected]>
Update com.unity.ml-agents/CHANGELOG.md

Co-authored-by: Miguel Alonso Jr. <[email protected]>
Update com.unity.ml-agents/CHANGELOG.md

Co-authored-by: Miguel Alonso Jr. <[email protected]>
Update com.unity.ml-agents/CHANGELOG.md

Co-authored-by: Maryam Honari <[email protected]>
Update ml-agents/mlagents/trainers/settings.py

Co-authored-by: Maryam Honari <[email protected]>
Update ml-agents/mlagents/trainers/cli_utils.py

Co-authored-by: Maryam Honari <[email protected]>

* Fix CR requests

* Add tests for discrete.

* Update ml-agents/mlagents/trainers/torch/distributions.py

Co-authored-by: Maryam Honari <[email protected]>

* Added more stable test.

* Return deterministic actions for training (#5615)

* Added more stable test.

* Fix the tests.

* Fix pre-commit

* Fix help line to pass precommit.

* support for deterministic inference in onnx (#5593)

* Init: actor.forward outputs separate deterministic actions

* changelog

* Renaming

* Add more tests

* Package changes to support deterministic inference (#5599)

* Init: actor.forward outputs separate deterministic actions

* fix tensor shape for discrete actions

* Add test and editor flag

- Add tests for deterministic sampling
- update editor and tooltips

* Reverting to "Deterministic Inference"

* dissect tests

* Update docs

* Update CHANGELOG.md

Co-authored-by: Chingiz Mardanov <[email protected]>
Co-authored-by: cmard <[email protected]>
  • Loading branch information
3 people authored Nov 18, 2021
1 parent 348bc9d commit 0de327c
Show file tree
Hide file tree
Showing 29 changed files with 469 additions and 66 deletions.
6 changes: 6 additions & 0 deletions com.unity.ml-agents/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ and this project adheres to
2. env_params.restarts_rate_limit_n (--restarts-rate-limit-n) [default=1]
3. env_params.restarts_rate_limit_period_s (--restarts-rate-limit-period-s) [default=60]


- Deterministic action selection is now supported during training and inference. (#5619)
- Added a new `--deterministic` cli flag to deterministically select the most probable actions in policy. The same thing can
be achieved by adding `deterministic: true` under `network_settings` of the run options configuration. (#5597)
- Extra tensors are now serialized to support deterministic action selection in ONNX. (#5593)
- Support inference with deterministic action selection in the Editor. (#5599)
### Bug Fixes
- Fixed a bug where the critics were not being normalized during training. (#5595)
- Fixed the bug where curriculum learning would crash because of the incorrect run_options parsing. (#5586)
Expand Down
4 changes: 3 additions & 1 deletion com.unity.ml-agents/Editor/BehaviorParametersEditor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ internal class BehaviorParametersEditor : UnityEditor.Editor
const string k_BrainParametersName = "m_BrainParameters";
const string k_ModelName = "m_Model";
const string k_InferenceDeviceName = "m_InferenceDevice";
const string k_DeterministicInference = "m_DeterministicInference";
const string k_BehaviorTypeName = "m_BehaviorType";
const string k_TeamIdName = "TeamId";
const string k_UseChildSensorsName = "m_UseChildSensors";
Expand Down Expand Up @@ -68,6 +69,7 @@ public override void OnInspectorGUI()
EditorGUILayout.PropertyField(so.FindProperty(k_ModelName), true);
EditorGUI.indentLevel++;
EditorGUILayout.PropertyField(so.FindProperty(k_InferenceDeviceName), true);
EditorGUILayout.PropertyField(so.FindProperty(k_DeterministicInference), true);
EditorGUI.indentLevel--;
}
needPolicyUpdate = needPolicyUpdate || EditorGUI.EndChangeCheck();
Expand Down Expand Up @@ -156,7 +158,7 @@ void DisplayFailedModelChecks()
{
var failedChecks = Inference.BarracudaModelParamLoader.CheckModel(
barracudaModel, brainParameters, sensors, actuatorComponents,
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType
observableAttributeSensorTotalSize, behaviorParameters.BehaviorType, behaviorParameters.DeterministicInference
);
foreach (var check in failedChecks)
{
Expand Down
6 changes: 4 additions & 2 deletions com.unity.ml-agents/Runtime/Academy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,16 @@ void EnvironmentReset()
/// <param name="inferenceDevice">
/// The inference device (CPU or GPU) the ModelRunner will use.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// Deterministic. </param>
/// <returns> The ModelRunner compatible with the input settings.</returns>
internal ModelRunner GetOrCreateModelRunner(
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice)
NNModel model, ActionSpec actionSpec, InferenceDevice inferenceDevice, bool deterministicInference = false)
{
var modelRunner = m_ModelRunners.Find(x => x.HasModel(model, inferenceDevice));
if (modelRunner == null)
{
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed);
modelRunner = new ModelRunner(model, actionSpec, inferenceDevice, m_InferenceSeed, deterministicInference);
m_ModelRunners.Add(modelRunner);
m_InferenceSeed++;
}
Expand Down
107 changes: 80 additions & 27 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,10 @@ public static int GetNumVisualInputs(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Array of the output tensor names of the model</returns>
public static string[] GetOutputNames(this Model model)
public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
{
var names = new List<string>();

Expand All @@ -122,13 +124,13 @@ public static string[] GetOutputNames(this Model model)
return names.ToArray();
}

if (model.HasContinuousOutputs())
if (model.HasContinuousOutputs(deterministicInference))
{
names.Add(model.ContinuousOutputName());
names.Add(model.ContinuousOutputName(deterministicInference));
}
if (model.HasDiscreteOutputs())
if (model.HasDiscreteOutputs(deterministicInference))
{
names.Add(model.DiscreteOutputName());
names.Add(model.DiscreteOutputName(deterministicInference));
}

var modelVersion = model.GetVersion();
Expand All @@ -149,8 +151,10 @@ public static string[] GetOutputNames(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model has continuous action outputs.</returns>
public static bool HasContinuousOutputs(this Model model)
public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
Expand All @@ -160,8 +164,13 @@ public static bool HasContinuousOutputs(this Model model)
}
else
{
return model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
bool hasStochasticOutput = !deterministicInference &&
model.outputs.Contains(TensorNames.ContinuousActionOutput);
bool hasDeterministicOutput = deterministicInference &&
model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput);

return (hasStochasticOutput || hasDeterministicOutput) &&
(int)model.GetTensorByName(TensorNames.ContinuousActionOutputShape)[0] > 0;
}
}

Expand Down Expand Up @@ -194,8 +203,10 @@ public static int ContinuousOutputSize(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Tensor name of continuous action output.</returns>
public static string ContinuousOutputName(this Model model)
public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
Expand All @@ -205,7 +216,7 @@ public static string ContinuousOutputName(this Model model)
}
else
{
return TensorNames.ContinuousActionOutput;
return deterministicInference ? TensorNames.DeterministicContinuousActionOutput : TensorNames.ContinuousActionOutput;
}
}

Expand All @@ -215,8 +226,10 @@ public static string ContinuousOutputName(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model has discrete action outputs.</returns>
public static bool HasDiscreteOutputs(this Model model)
public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
{
if (model == null)
return false;
Expand All @@ -226,7 +239,12 @@ public static bool HasDiscreteOutputs(this Model model)
}
else
{
return model.outputs.Contains(TensorNames.DiscreteActionOutput) && model.DiscreteOutputSize() > 0;
bool hasStochasticOutput = !deterministicInference &&
model.outputs.Contains(TensorNames.DiscreteActionOutput);
bool hasDeterministicOutput = deterministicInference &&
model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput);
return (hasStochasticOutput || hasDeterministicOutput) &&
model.DiscreteOutputSize() > 0;
}
}

Expand Down Expand Up @@ -279,8 +297,10 @@ public static int DiscreteOutputSize(this Model model)
/// <param name="model">
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>Tensor name of discrete action output.</returns>
public static string DiscreteOutputName(this Model model)
public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
{
if (model == null)
return null;
Expand All @@ -290,7 +310,7 @@ public static string DiscreteOutputName(this Model model)
}
else
{
return TensorNames.DiscreteActionOutput;
return deterministicInference ? TensorNames.DeterministicDiscreteActionOutput : TensorNames.DiscreteActionOutput;
}
}

Expand All @@ -316,9 +336,11 @@ public static bool SupportsContinuousAndDiscrete(this Model model)
/// The Barracuda engine model for loading static parameters.
/// </param>
/// <param name="failedModelChecks">Output list of failure messages</param>
///
///<param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>True if the model contains all the expected tensors.</returns>
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks)
/// TODO: add checks for deterministic actions
public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
{
// Check the presence of model version
var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
Expand All @@ -343,7 +365,9 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
// Check the presence of action output tensor
if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
!model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DiscreteActionOutput))
!model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
!model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
!model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model does not contain any Action Output Node.")
Expand Down Expand Up @@ -373,22 +397,51 @@ public static bool CheckExpectedTensors(this Model model, List<FailedCheck> fail
}
else
{
if (model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
);
return false;
return false;
}

else if (!model.HasContinuousOutputs(deterministicInference))
{
var actionType = deterministicInference ? "deterministic" : "stochastic";
var actionName = deterministicInference ? "Deterministic" : "";
failedModelChecks.Add(
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Continuous Action Output Tensor. Uncheck `Deterministic inference` flag..")
);
return false;
}
}
if (model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)

if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
{
failedModelChecks.Add(
FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
);
return false;
return false;
}
else if (!model.HasDiscreteOutputs(deterministicInference))
{
var actionType = deterministicInference ? "deterministic" : "stochastic";
var actionName = deterministicInference ? "Deterministic" : "";
failedModelChecks.Add(
FailedCheck.Warning($"The model uses {actionType} inference but does not contain {actionName} Discrete Action Output Tensor. Uncheck `Deterministic inference` flag.")
);
return false;
}

}




}
return true;
}
Expand Down
24 changes: 16 additions & 8 deletions com.unity.ml-agents/Runtime/Inference/BarracudaModelParamLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,17 @@ public static FailedCheck CheckModelVersion(Model model)
/// <param name="actuatorComponents">Attached actuator components</param>
/// <param name="observableAttributeTotalSize">Sum of the sizes of all ObservableAttributes.</param>
/// <param name="behaviorType">BehaviorType or the Agent to check.</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>A IEnumerable of the checks that failed</returns>
public static IEnumerable<FailedCheck> CheckModel(
Model model,
BrainParameters brainParameters,
ISensor[] sensors,
ActuatorComponent[] actuatorComponents,
int observableAttributeTotalSize = 0,
BehaviorType behaviorType = BehaviorType.Default
BehaviorType behaviorType = BehaviorType.Default,
bool deterministicInference = false
)
{
List<FailedCheck> failedModelChecks = new List<FailedCheck>();
Expand All @@ -148,7 +151,7 @@ public static IEnumerable<FailedCheck> CheckModel(
return failedModelChecks;
}

var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks);
var hasExpectedTensors = model.CheckExpectedTensors(failedModelChecks, deterministicInference);
if (!hasExpectedTensors)
{
return failedModelChecks;
Expand Down Expand Up @@ -181,7 +184,7 @@ public static IEnumerable<FailedCheck> CheckModel(
else if (modelApiVersion == (int)ModelApiVersion.MLAgents2_0)
{
failedModelChecks.AddRange(
CheckInputTensorPresence(model, brainParameters, memorySize, sensors)
CheckInputTensorPresence(model, brainParameters, memorySize, sensors, deterministicInference)
);
failedModelChecks.AddRange(
CheckInputTensorShape(model, brainParameters, sensors, observableAttributeTotalSize)
Expand All @@ -195,7 +198,7 @@ public static IEnumerable<FailedCheck> CheckModel(
);

failedModelChecks.AddRange(
CheckOutputTensorPresence(model, memorySize)
CheckOutputTensorPresence(model, memorySize, deterministicInference)
);
return failedModelChecks;
}
Expand Down Expand Up @@ -318,14 +321,17 @@ ISensor[] sensors
/// The memory size that the model is expecting.
/// </param>
/// <param name="sensors">Array of attached sensor components</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// Deterministic. </param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckInputTensorPresence(
Model model,
BrainParameters brainParameters,
int memory,
ISensor[] sensors
ISensor[] sensors,
bool deterministicInference = false
)
{
var failedModelChecks = new List<FailedCheck>();
Expand Down Expand Up @@ -356,7 +362,7 @@ ISensor[] sensors
}

// If the model uses discrete control but does not have an input for action masks
if (model.HasDiscreteOutputs())
if (model.HasDiscreteOutputs(deterministicInference))
{
if (!tensorsNames.Contains(TensorNames.ActionMaskPlaceholder))
{
Expand All @@ -376,17 +382,19 @@ ISensor[] sensors
/// The Barracuda engine model for loading static parameters
/// </param>
/// <param name="memory">The memory size that the model is expecting/</param>
/// <param name="deterministicInference"> Inference only: set to true if the action selection from model should be
/// deterministic. </param>
/// <returns>
/// A IEnumerable of the checks that failed
/// </returns>
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory)
static IEnumerable<FailedCheck> CheckOutputTensorPresence(Model model, int memory, bool deterministicInference = false)
{
var failedModelChecks = new List<FailedCheck>();

// If there is no Recurrent Output but the model is Recurrent.
if (memory > 0)
{
var allOutputs = model.GetOutputNames().ToList();
var allOutputs = model.GetOutputNames(deterministicInference).ToList();
if (!allOutputs.Any(x => x == TensorNames.RecurrentOutput))
{
failedModelChecks.Add(
Expand Down
Loading

0 comments on commit 0de327c

Please sign in to comment.