Gradient_Reversal_Layer.py
# Original reference: https://github.com/janfreyberg/pytorch-revgrad
# Only a weight parameter is added to scale the reversed gradient.
import torch


class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_, weight):
        ctx.save_for_backward(input_)
        ctx.weight = weight
        output = input_
        return output

    @staticmethod
    def backward(ctx, grad_output):  # pragma: no cover
        grad_input = None
        if ctx.needs_input_grad[0]:
            # Reverse the gradient and scale it by the weight.
            grad_input = -ctx.weight * grad_output
        return grad_input, None


class GRL(torch.nn.Module):
    def __init__(self, weight=1.0):
        """
        A gradient reversal layer.
        This layer has no trainable parameters; it acts as the identity in the
        forward pass and reverses (and scales by `weight`) the gradient in the
        backward pass.
        """
        super(GRL, self).__init__()
        self.weight = weight

    def forward(self, input_):
        return Func.apply(
            input_,
            torch.FloatTensor([self.weight]).to(device=input_.device)
        )
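

# A minimal usage sketch (not part of the original file): the feature extractor,
# domain head, layer sizes, and weight value below are illustrative assumptions,
# not taken from this repository. It only demonstrates that GRL leaves the
# forward pass unchanged while negating and scaling the gradient flowing back
# through it, which is the mechanism adversarial domain-adaptation setups rely on.
if __name__ == "__main__":
    torch.manual_seed(0)

    features = torch.nn.Linear(8, 4)      # hypothetical feature extractor
    domain_head = torch.nn.Linear(4, 2)   # hypothetical domain classifier
    grl = GRL(weight=0.5)

    x = torch.randn(3, 8)
    feats = features(x)
    logits = domain_head(grl(feats))      # forward output identical to domain_head(feats)

    loss = logits.sum()
    loss.backward()

    # Gradients reaching `features` have been reversed and scaled by 0.5.
    print(features.weight.grad)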