|
| 1 | +import torch |
| 2 | +from torch.optim import Optimizer |
| 3 | + |
| 4 | + |
| 5 | +class RMSpropTFLike(Optimizer): |
| 6 | + r"""Implements RMSprop algorithm with closer match to Tensorflow version. |
| 7 | +
|
| 8 | + For reproducibility with original stable-baselines. Use this |
| 9 | + version with e.g. A2C for stabler learning than with the PyTorch |
| 10 | + RMSProp. Based on the PyTorch v1.5.0 implementation of RMSprop. |
| 11 | +
|
| 12 | + See a more throughout conversion in pytorch-image-models repository: |
| 13 | + https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/rmsprop_tf.py |
| 14 | +
|
| 15 | + Changes to the original RMSprop: |
| 16 | + - Move epsilon inside square root |
| 17 | + - Initialize squared gradient to ones rather than zeros |
| 18 | +
|
| 19 | + Proposed by G. Hinton in his |
| 20 | + `course <http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_. |
| 21 | +
|
| 22 | + The centered version first appears in `Generating Sequences |
| 23 | + With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_. |
| 24 | +
|
| 25 | + The implementation here takes the square root of the gradient average before |
| 26 | + adding epsilon (note that TensorFlow interchanges these two operations). The effective |
| 27 | + learning rate is thus :math:`\alpha/(\sqrt{v} + \epsilon)` where :math:`\alpha` |
| 28 | + is the scheduled learning rate and :math:`v` is the weighted moving average |
| 29 | + of the squared gradient. |
| 30 | +
|
| 31 | + Arguments: |
| 32 | + params (iterable): iterable of parameters to optimize or dicts defining |
| 33 | + parameter groups |
| 34 | + lr (float, optional): learning rate (default: 1e-2) |
| 35 | + momentum (float, optional): momentum factor (default: 0) |
| 36 | + alpha (float, optional): smoothing constant (default: 0.99) |
| 37 | + eps (float, optional): term added to the denominator to improve |
| 38 | + numerical stability (default: 1e-8) |
| 39 | + centered (bool, optional) : if ``True``, compute the centered RMSProp, |
| 40 | + the gradient is normalized by an estimation of its variance |
| 41 | + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) |
| 42 | +
|
| 43 | + """ |
| 44 | + |
| 45 | + def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False): |
| 46 | + if not 0.0 <= lr: |
| 47 | + raise ValueError("Invalid learning rate: {}".format(lr)) |
| 48 | + if not 0.0 <= eps: |
| 49 | + raise ValueError("Invalid epsilon value: {}".format(eps)) |
| 50 | + if not 0.0 <= momentum: |
| 51 | + raise ValueError("Invalid momentum value: {}".format(momentum)) |
| 52 | + if not 0.0 <= weight_decay: |
| 53 | + raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) |
| 54 | + if not 0.0 <= alpha: |
| 55 | + raise ValueError("Invalid alpha value: {}".format(alpha)) |
| 56 | + |
| 57 | + defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay) |
| 58 | + super(RMSpropTFLike, self).__init__(params, defaults) |
| 59 | + |
| 60 | + def __setstate__(self, state): |
| 61 | + super(RMSpropTFLike, self).__setstate__(state) |
| 62 | + for group in self.param_groups: |
| 63 | + group.setdefault("momentum", 0) |
| 64 | + group.setdefault("centered", False) |
| 65 | + |
| 66 | + @torch.no_grad() |
| 67 | + def step(self, closure=None): |
| 68 | + """Performs a single optimization step. |
| 69 | +
|
| 70 | + Arguments: |
| 71 | + closure (callable, optional): A closure that reevaluates the model |
| 72 | + and returns the loss. |
| 73 | + """ |
| 74 | + loss = None |
| 75 | + if closure is not None: |
| 76 | + with torch.enable_grad(): |
| 77 | + loss = closure() |
| 78 | + |
| 79 | + for group in self.param_groups: |
| 80 | + for p in group["params"]: |
| 81 | + if p.grad is None: |
| 82 | + continue |
| 83 | + grad = p.grad |
| 84 | + if grad.is_sparse: |
| 85 | + raise RuntimeError("RMSpropTF does not support sparse gradients") |
| 86 | + state = self.state[p] |
| 87 | + |
| 88 | + # State initialization |
| 89 | + if len(state) == 0: |
| 90 | + state["step"] = 0 |
| 91 | + # PyTorch initialized to zeros here |
| 92 | + state["square_avg"] = torch.ones_like(p, memory_format=torch.preserve_format) |
| 93 | + if group["momentum"] > 0: |
| 94 | + state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format) |
| 95 | + if group["centered"]: |
| 96 | + state["grad_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format) |
| 97 | + |
| 98 | + square_avg = state["square_avg"] |
| 99 | + alpha = group["alpha"] |
| 100 | + |
| 101 | + state["step"] += 1 |
| 102 | + |
| 103 | + if group["weight_decay"] != 0: |
| 104 | + grad = grad.add(p, alpha=group["weight_decay"]) |
| 105 | + |
| 106 | + square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha) |
| 107 | + |
| 108 | + if group["centered"]: |
| 109 | + grad_avg = state["grad_avg"] |
| 110 | + grad_avg.mul_(alpha).add_(grad, alpha=1 - alpha) |
| 111 | + # PyTorch added epsilon after square root |
| 112 | + # avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_().add_(group['eps']) |
| 113 | + avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).add_(group["eps"]).sqrt_() |
| 114 | + else: |
| 115 | + # PyTorch added epsilon after square root |
| 116 | + # avg = square_avg.sqrt().add_(group['eps']) |
| 117 | + avg = square_avg.add(group["eps"]).sqrt_() |
| 118 | + |
| 119 | + if group["momentum"] > 0: |
| 120 | + buf = state["momentum_buffer"] |
| 121 | + buf.mul_(group["momentum"]).addcdiv_(grad, avg) |
| 122 | + p.add_(buf, alpha=-group["lr"]) |
| 123 | + else: |
| 124 | + p.addcdiv_(grad, avg, value=-group["lr"]) |
| 125 | + |
| 126 | + return loss |
0 commit comments