
BrownianNotion

Description

Hi, I'm new here; feedback is welcome, and let me know if I've missed anything!

Due to a bug in PyTorch 2.8.0's F.linear on MPS (pytorch/pytorch#161640), the lines below

out = F.linear(
    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
    w,
    self.b_O,
)

from https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/abstract_attention.py#L302-L306 produce incorrect attention outputs on MPS; CPU works fine. I'm on macOS Sequoia 15.6.1.

I've added a unit test that reproduces the issue in commit f903629.
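
For reference, a hypothetical minimal MPS-vs-CPU consistency check in the spirit of that test could look like the sketch below. This is not the test from the commit: the test name and shapes are made up, and whether it actually triggers the upstream bug depends on the exact strides of z in the real attention forward pass.

import pytest
import torch
import torch.nn.functional as F

@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS not available")
def test_out_projection_matches_cpu_on_mps():
    # Illustrative shapes; the real trigger depends on how z is produced upstream.
    torch.manual_seed(0)
    batch, pos, n_heads, d_head, d_model = 2, 5, 4, 8, 32
    z = torch.randn(batch, pos, n_heads, d_head)
    w = torch.randn(d_model, n_heads * d_head)
    b = torch.randn(d_model)

    # Reference result on CPU.
    expected = F.linear(z.reshape(batch, pos, n_heads * d_head), w, b)

    # Same computation on MPS, moved back to CPU for comparison.
    actual = F.linear(
        z.to("mps").reshape(batch, pos, n_heads * d_head),
        w.to("mps"),
        b.to("mps"),
    ).cpu()

    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)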

To fix this, I have replaced F.linear with einops.einsum, which is also more consistent with the rest of the class.
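
For illustration, here is a self-contained sketch (on CPU) of why the two formulations agree, assuming the shape conventions used in abstract_attention.py: z is [batch, pos, head_index, d_head], W_O is [head_index, d_head, d_model], and w is W_O rearranged to [d_model, head_index * d_head]. The exact replacement in the PR may differ in detail.

import torch
import torch.nn.functional as F
import einops

# Illustrative shapes only.
batch, pos, n_heads, d_head, d_model = 2, 3, 4, 8, 32
z = torch.randn(batch, pos, n_heads, d_head)   # per-head attention results
W_O = torch.randn(n_heads, d_head, d_model)    # output projection weights
b_O = torch.randn(d_model)                     # output projection bias

# F.linear route: flatten the head dimensions and use a (d_model, n_heads * d_head) weight.
w = einops.rearrange(W_O, "head_index d_head d_model -> d_model (head_index d_head)")
out_linear = F.linear(z.reshape(batch, pos, n_heads * d_head), w, b_O)

# einops.einsum route: contract over head_index and d_head directly, no flattening needed.
out_einsum = (
    einops.einsum(
        z,
        W_O,
        "batch pos head_index d_head, head_index d_head d_model -> batch pos d_model",
    )
    + b_O
)

torch.testing.assert_close(out_einsum, out_linear)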

Related issues: #1008 #1062

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

@BrownianNotion BrownianNotion changed the title Fix attn mps Fix attention calculation on mps Sep 27, 2025
@BrownianNotion BrownianNotion changed the title Fix attention calculation on mps Fix attention calculation on mps for torch 2.8.0 Sep 27, 2025
@BrownianNotion
Author

An alternative, simpler and less invasive fix: call .contiguous() on the reshaped z and on w before the F.linear call.
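
Concretely, that workaround would look something like the following (same variables as the excerpt above; an untested sketch, not necessarily the final change):

out = F.linear(
    # Force contiguous memory layouts so the MPS F.linear kernel sees plain strides.
    z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads).contiguous(),
    w.contiguous(),
    self.b_O,
)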
