diff --git a/src/TorchSharp/Optimizers/Rprop.cs b/src/TorchSharp/Optimizers/Rprop.cs
index 47e01d982..abe9d736e 100644
--- a/src/TorchSharp/Optimizers/Rprop.cs
+++ b/src/TorchSharp/Optimizers/Rprop.cs
@@ -86,7 +86,7 @@ public class Rprop : OptimizerHelper
/// Minimum allowed step size.
/// Maximum allowed step size.
/// Maximize the params based on the objective, instead of minimizing.
- public Rprop(IEnumerable parameters, double lr = 0.01, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
+ public Rprop(IEnumerable parameters, double lr = 1e-2, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
: this(new ParamGroup[] { new() { Parameters = parameters } }, lr, etaminus, etaplus, min_step, max_step, maximize)
{
}
@@ -156,10 +156,6 @@ public override Tensor step(Func closure = null)
state.step += 1;
- grad = (max_step != 0)
- ? grad.add(param, alpha: max_step)
- : grad.alias();
-
var sign = grad.mul(state.prev).sign();
sign[sign.gt(0)] = (Tensor)etaplus;
sign[sign.lt(0)] = (Tensor)etaminus;
diff --git a/test/TorchSharpTest/TestTraining.cs b/test/TorchSharpTest/TestTraining.cs
index b1ec9d4a2..d73c5c0e6 100644
--- a/test/TorchSharpTest/TestTraining.cs
+++ b/test/TorchSharpTest/TestTraining.cs
@@ -1170,7 +1170,7 @@ public void TrainingRprop()
var loss = TrainLoop(seq, x, y, optimizer);
- LossIsClose(229.68f, loss);
+ LossIsClose(77.279f, loss);
}
@@ -1187,7 +1187,7 @@ public void TrainingRpropMax()
var loss = TrainLoop(seq, x, y, optimizer, maximize:true);
- LossIsClose(229.68f, -loss);
+ LossIsClose(77.279f, -loss);
}
[Fact]
@@ -1203,7 +1203,7 @@ public void TrainingRpropEtam()
var loss = TrainLoop(seq, x, y, optimizer);
- LossIsClose(201.417f, loss);
+ LossIsClose(171.12f, loss);
}
[Fact]
@@ -1219,7 +1219,7 @@ public void TrainingRpropEtap()
var loss = TrainLoop(seq, x, y, optimizer);
- LossIsClose(221.365f, loss);
+ LossIsClose(65.859f, loss);
}
@@ -1240,7 +1240,7 @@ public void TrainingRpropParamGroups()
var loss = TrainLoop(seq, x, y, optimizer);
- LossIsClose(78.619f, loss);
+ LossIsClose(66.479f, loss);
}
///