From 7b23a0a5d2d07c4b67d81e2efb6020a834bc4f2f Mon Sep 17 00:00:00 2001
From: Lucas Nestler <39779310+ClashLuke@users.noreply.github.com>
Date: Sun, 26 Sep 2021 16:50:01 +0200
Subject: [PATCH] perf(AconC): replicate shared values

---
 rgn2_replica/rgn2.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/rgn2_replica/rgn2.py b/rgn2_replica/rgn2.py
index 55e3123..a0a5752 100644
--- a/rgn2_replica/rgn2.py
+++ b/rgn2_replica/rgn2.py
@@ -139,12 +139,15 @@ def __init__(self, width):
 
     @torch.jit.script_method
     def forward(self, x):
         """ Inputs (B, L, C) --> Outputs (B, L, C). """
-        p1, p2, beta = self.p1, self.p2, self.beta
-        while x.dim() > p1.dim():
-            p1 = p1.unsqueeze(0)
-            p2 = p2.unsqueeze(0)
-            beta = beta.unsqueeze(0)
-        return (p1 * x - p2 * x) * torch.sigmoid(beta * (p1 * x - p2 * x)) + p2 * x
+        # TODO: Inspect how much beta changes. If it stays close to 1, use F.silu instead.
+        # TODO: (p1 - p2) * x shouldn't do much more than just p1 * x
+        # TODO: Considering that PReLU overfits, check if F.silu(x) + x is better
+        shape = (1,) * (x.dim() - 1) + (self.p1.size(0),)
+        p1 = self.p1.view(*shape)
+        p2 = self.p2.view(*shape)
+        beta = self.beta.view(*shape)
+        x_mul = (p1 - p2) * x
+        return x_mul * x_mul.mul(beta).sigmoid() + p2 * x
 
 # from https://github.com/FlorianWilhelm/mlstm4reco/blob/master/src/mlstm4reco/layers.py