Skip to content
Merged
6 changes: 1 addition & 5 deletions src/TorchSharp/Optimizers/Rprop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public class Rprop : OptimizerHelper
/// <param name="min_step">Minimum allowed step size.</param>
/// <param name="max_step">Maximum allowed step size.</param>
/// <param name="maximize">Maximize the params based on the objective, instead of minimizing.</param>
public Rprop(IEnumerable<Parameter> parameters, double lr = 0.01, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
public Rprop(IEnumerable<Parameter> parameters, double lr = 1e-2, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
: this(new ParamGroup[] { new() { Parameters = parameters } }, lr, etaminus, etaplus, min_step, max_step, maximize)
{
}
Expand Down Expand Up @@ -156,10 +156,6 @@ public override Tensor step(Func<Tensor> closure = null)

state.step += 1;

grad = (max_step != 0)
? grad.add(param, alpha: max_step)
: grad.alias();

var sign = grad.mul(state.prev).sign();
sign[sign.gt(0)] = (Tensor)etaplus;
sign[sign.lt(0)] = (Tensor)etaminus;
Expand Down
10 changes: 5 additions & 5 deletions test/TorchSharpTest/TestTraining.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ public void TrainingRprop()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(229.68f, loss);
LossIsClose(77.279f, loss);
}


Expand All @@ -1187,7 +1187,7 @@ public void TrainingRpropMax()

var loss = TrainLoop(seq, x, y, optimizer, maximize:true);

LossIsClose(229.68f, -loss);
LossIsClose(77.279f, -loss);
}

[Fact]
Expand All @@ -1203,7 +1203,7 @@ public void TrainingRpropEtam()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(201.417f, loss);
LossIsClose(171.12f, loss);
}

[Fact]
Expand All @@ -1219,7 +1219,7 @@ public void TrainingRpropEtap()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(221.365f, loss);
LossIsClose(65.859f, loss);
}


Expand All @@ -1240,7 +1240,7 @@ public void TrainingRpropParamGroups()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(78.619f, loss);
LossIsClose(66.479f, loss);
}

/// <summary>
Expand Down