Implement Gemma 2 models #486

Closed
3 tasks done
EricLBuehler opened this issue Jun 28, 2024 · 1 comment
Labels
models Additions to model or architectures

Comments

@EricLBuehler
Owner

EricLBuehler commented Jun 28, 2024

Need to modify the Gemma model implementation for Gemma 2.

Changelist over the original Gemma, and status:

  • Sliding window attention - layers with idx % 2 != 0 (every other layer) use a sliding window (see the mask sketch after this list)
    • Affects KV cache retrieval
    • Affects sliding window mask generation
  • Logit soft capping
    • In attention, between computing Q*K^T * s and the matmul with V:
        if self.config.attn_logit_softcapping is not None:
            attn_weights = attn_weights / self.config.attn_logit_softcapping
            attn_weights = torch.tanh(attn_weights)
            attn_weights = attn_weights * self.config.attn_logit_softcapping
    • After the LM head:
        if self.config.final_logit_softcapping is not None:
            logits = logits / self.config.final_logit_softcapping
            logits = torch.tanh(logits)
            logits = logits * self.config.final_logit_softcapping
  • Use query_pre_attn_scalar**-0.5 as the attention scale instead of 1/sqrt(head_dim) (see the scaling sketch below)
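
A minimal PyTorch sketch of the alternating sliding-window masking described above, assuming an additive attention mask in the style of the Hugging Face transformers Gemma 2 code; the make_causal_mask helper and the concrete num_layers/window/seq_len values are illustrative assumptions, not the project's actual implementation:

    import torch

    def make_causal_mask(seq_len, sliding_window=None, dtype=torch.float32):
        # Standard causal mask: query i may attend to keys j <= i.
        mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
        mask = torch.triu(mask, diagonal=1)
        if sliding_window is not None:
            # Additionally mask keys more than `sliding_window` tokens behind
            # the query, i.e. keep only j > i - sliding_window.
            too_old = torch.tril(
                torch.ones(seq_len, seq_len, dtype=torch.bool),
                diagonal=-sliding_window,
            )
            mask = mask.masked_fill(too_old, torch.finfo(dtype).min)
        return mask

    # Layers with idx % 2 != 0 get the sliding-window mask; the rest use full
    # causal attention. Illustrative values only.
    num_layers, window, seq_len = 4, 4, 8
    masks = [
        make_causal_mask(seq_len, sliding_window=window if idx % 2 != 0 else None)
        for idx in range(num_layers)
    ]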
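
And a small sketch of the changed attention scale, assuming query_pre_attn_scalar is read from the model config as in the Hugging Face implementation; the shapes and the value 256 are placeholder assumptions:

    import math
    import torch

    head_dim, query_pre_attn_scalar = 256, 256   # placeholder config values
    q = torch.randn(1, 8, 16, head_dim)          # (batch, heads, seq, head_dim)
    k = torch.randn(1, 8, 16, head_dim)

    # Original Gemma scaled attention scores by 1/sqrt(head_dim).
    scale_gemma1 = 1.0 / math.sqrt(head_dim)
    # Gemma 2 scales by query_pre_attn_scalar ** -0.5 instead.
    scale_gemma2 = query_pre_attn_scalar ** -0.5

    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * scale_gemma2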

Links

@EricLBuehler EricLBuehler added the models Additions to model or architectures label Jun 28, 2024
@EricLBuehler
Owner Author

Implemented in #490.
