Skip to content
Merged
22 changes: 11 additions & 11 deletions src/TorchSharp/Optimizers/Rprop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public class Rprop : OptimizerHelper
/// <param name="min_step">Minimum allowed step size.</param>
/// <param name="max_step">Maximum allowed step size.</param>
/// <param name="maximize">If true, maximize the objective with respect to the parameters instead of minimizing it.</param>
public Rprop(IEnumerable<Parameter> parameters, double lr = 0.01, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
public Rprop(IEnumerable<Parameter> parameters, double lr = 1e-2, double etaminus = 0.5, double etaplus = 1.2, double min_step = 1e-6, double max_step = 50, bool maximize = false)
Comment thread
alinpahontu2912 marked this conversation as resolved.
: this(new ParamGroup[] { new() { Parameters = parameters } }, lr, etaminus, etaplus, min_step, max_step, maximize)
{
}
Expand Down Expand Up @@ -156,20 +156,20 @@ public override Tensor step(Func<Tensor> closure = null)

state.step += 1;

grad = (max_step != 0)
? grad.add(param, alpha: max_step)
: grad.alias();
var grad_prod = grad.mul(state.prev);
Comment thread
alinpahontu2912 marked this conversation as resolved.
Outdated
Comment thread
alinpahontu2912 marked this conversation as resolved.
Outdated

var sign = grad.mul(state.prev).sign();
sign[sign.gt(0)] = (Tensor)etaplus;
sign[sign.lt(0)] = (Tensor)etaminus;
sign[sign.eq(0)] = (Tensor)1;
var pos_mask = grad_prod.gt(0);
var neg_mask = grad_prod.lt(0);
var zero_mask = grad_prod.eq(0);

state.step_size.mul_(sign).clamp_(min_step, max_step);
var step_size_update = pos_mask.to(torch.float64) * etaplus +
neg_mask.to(torch.float64) * etaminus +
zero_mask.to(torch.float64);

grad = grad.clone();
state.step_size.mul_(step_size_update).clamp_(min_step, max_step);

grad.index_put_(0, sign.eq(etaminus));
grad = grad.clone();
grad.masked_fill_(neg_mask, 0);

param.addcmul_(grad.sign(), state.step_size, -1);

Expand Down
12 changes: 6 additions & 6 deletions test/TorchSharpTest/TestTraining.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ public void TrainingRprop()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(229.68f, loss);
LossIsClose(77.279f, loss);
}


Expand All @@ -1187,7 +1187,7 @@ public void TrainingRpropMax()

var loss = TrainLoop(seq, x, y, optimizer, maximize:true);

LossIsClose(229.68f, -loss);
LossIsClose(77.279f, -loss);
}

[Fact]
Expand All @@ -1202,8 +1202,8 @@ public void TrainingRpropEtam()
var optimizer = torch.optim.Rprop(seq.parameters(), etaminus: 0.55);

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(201.417f, loss);
Comment thread
alinpahontu2912 marked this conversation as resolved.
Outdated
LossIsClose(171.12f, loss);
}

[Fact]
Expand All @@ -1219,7 +1219,7 @@ public void TrainingRpropEtap()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(221.365f, loss);
LossIsClose(65.859f, loss);
}


Expand All @@ -1240,7 +1240,7 @@ public void TrainingRpropParamGroups()

var loss = TrainLoop(seq, x, y, optimizer);

LossIsClose(78.619f, loss);
LossIsClose(66.479f, loss);
}

/// <summary>
Expand Down