From 588aa5a486d5bc70ded31f4b84d16ace9ccb9cd6 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 16 May 2024 22:55:58 -0700 Subject: [PATCH] fix: identity matrix different dtype than output --- scripts/pag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/pag.py b/scripts/pag.py index 9bc310e..37f3433 100644 --- a/scripts/pag.py +++ b/scripts/pag.py @@ -354,12 +354,12 @@ def pag_pre_hook(module, input, kwargs, output): return batch_size, seq_len, inner_dim = output.shape - identity = torch.eye(seq_len).expand(batch_size, -1, -1).to(shared.device) + last_to_v = getattr(module, 'pag_last_to_v', None) # get the last to_v output and save it - last_to_v = getattr(module, 'pag_last_to_v', None) if last_to_v is not None: - new_output = torch.einsum('bij,bjk->bik', identity, last_to_v) + identity = torch.eye(seq_len).expand(batch_size, -1, -1).to(device=shared.device, dtype=last_to_v.dtype) + new_output = torch.einsum('bij,bjk->bik', identity, last_to_v).to(dtype=output.dtype) return new_output else: # this is bad