diff --git a/src/TorchSharp/Optimizers/Rprop.cs b/src/TorchSharp/Optimizers/Rprop.cs index 47e01d982..abe9d736e 100644 --- a/src/TorchSharp/Optimizers/Rprop.cs +++ b/src/TorchSharp/Optimizers/Rprop.cs @@ -86,7 +86,7 @@ public class Rprop : OptimizerHelper /// Minimum allowed step size. /// Maximum allowed step size. /// Maximize the params based on the objective, instead of minimizing. - public Rprop(IEnumerable parameters, double lr = 0.01, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false) + public Rprop(IEnumerable parameters, double lr = 1e-2, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false) : this(new ParamGroup[] { new() { Parameters = parameters } }, lr, etaminus, etaplus, min_step, max_step, maximize) { } @@ -156,10 +156,6 @@ public override Tensor step(Func closure = null) state.step += 1; - grad = (max_step != 0) - ? grad.add(param, alpha: max_step) - : grad.alias(); - var sign = grad.mul(state.prev).sign(); sign[sign.gt(0)] = (Tensor)etaplus; sign[sign.lt(0)] = (Tensor)etaminus; diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs index b1ec9d4a2..d73c5c0e6 100644 --- a/test/TorchSharpTest/TestTraining.cs +++ b/test/TorchSharpTest/TestTraining.cs @@ -1170,7 +1170,7 @@ public void TrainingRprop() var loss = TrainLoop(seq, x, y, optimizer); - LossIsClose(229.68f, loss); + LossIsClose(77.279f, loss); } @@ -1187,7 +1187,7 @@ public void TrainingRpropMax() var loss = TrainLoop(seq, x, y, optimizer, maximize:true); - LossIsClose(229.68f, -loss); + LossIsClose(77.279f, -loss); } [Fact] @@ -1203,7 +1203,7 @@ public void TrainingRpropEtam() var loss = TrainLoop(seq, x, y, optimizer); - LossIsClose(201.417f, loss); + LossIsClose(171.12f, loss); } [Fact] @@ -1219,7 +1219,7 @@ public void TrainingRpropEtap() var loss = TrainLoop(seq, x, y, optimizer); - LossIsClose(221.365f, loss); + LossIsClose(65.859f, loss); } @@ -1240,7 +1240,7 @@ public void TrainingRpropParamGroups() var loss = TrainLoop(seq, x, y, optimizer); - LossIsClose(78.619f, loss); + LossIsClose(66.479f, loss); } ///