Customizable Embedder and Logit Mapper #288


Conversation

pradeep-pyro (Contributor)

Hi @lucidrains, I use x-transformers extensively and often run into the need to use custom embedder and logit mapper classes together with the AutoregressiveWrapper class, e.g., to implement pointer networks (https://arxiv.org/abs/1506.03134). Currently this is not possible, so I work with the lower-level Decoder class and implement the wrapper myself.
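For concreteness, a pointer network emits logits over the positions of the input sequence rather than over a fixed vocabulary, so the output head needs the encoder states at logit time, which the stock LinearNoBias head cannot accept. A minimal sketch of such a head (dot-product scoring for brevity, where the paper uses additive attention; all names here are illustrative):

```python
import torch
from torch import nn

class PointerLogits(nn.Module):
    # hypothetical pointer-style output head: logits are attention
    # scores over encoder positions, not over a fixed vocabulary
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias = False)
        self.to_k = nn.Linear(dim, dim, bias = False)

    def forward(self, embed, context):
        # embed:   (batch, tgt_len, dim) decoder hidden states
        # context: (batch, src_len, dim) encoder hidden states
        q = self.to_q(embed)
        k = self.to_k(context)
        # (batch, tgt_len, src_len) - one logit per input position
        return torch.einsum('b i d, b j d -> b i j', q, k) * (q.shape[-1] ** -0.5)
```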

In this PR, I tried to enable support for this by slightly reorganizing the TokenEmbedding class and adding a LogitMapper class to be used as the TransformerWrapper.to_logits layer. The idea is that we can implement custom classes conforming to these interfaces and everything should work seamlessly. Both classes accept kwargs in their forward methods, which lets us pass additional tensors to them.

The unit tests all pass and I added some additional ones. I hope I didn't break anything! Let me know if you're ok to merge this change, or if you have better ways of achieving this behavior. Thank you.

lucidrains (Owner)

@pradeep-pyro this is cool, thank you!

do you think it is possible to just keep the default as LinearNoBias, without the LogitMapper indirection?

also, perhaps we can just have token_emb_kwargs: dict and to_logits_kwargs: dict in the forward method instead
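ie. roughly something like this (just a sketch, the actual forward takes many more arguments):

```python
# rough sketch of TransformerWrapper.forward with the two dicts threaded through
def forward(self, x, token_emb_kwargs: dict = dict(), to_logits_kwargs: dict = dict(), **kwargs):
    x = self.token_emb(x, **token_emb_kwargs)
    x = self.attn_layers(x, **kwargs)
    return self.to_logits(x, **to_logits_kwargs)
```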

pradeep-pyro (Contributor, Author)

Thanks for the feedback!

> also, perhaps we can just have token_emb_kwargs: dict and to_logits_kwargs: dict in the forward method instead

This solution is perfect; let me refactor. With this change, I think I can keep using LinearNoBias.

pradeep-pyro (Contributor, Author)

@lucidrains I modified the code as you suggested. It is certainly cleaner this way. Please take another look when you get the chance. Thanks!
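To make the resulting usage concrete, here is a sketch against the new interface (the empty dicts simply stand in for whatever tensors a custom embedder or output head would consume):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 256,
    max_seq_len = 1024,
    attn_layers = Decoder(dim = 512, depth = 2, heads = 8)
)

x = torch.randint(0, 256, (1, 64))

# extra tensors for a custom embedder / output head are passed through
# these dicts; the default modules need nothing extra
logits = model(x, token_emb_kwargs = dict(), to_logits_kwargs = dict())
```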

lucidrains (Owner)

@pradeep-pyro looks great Pradeep! graciously accept! 🙏

lucidrains merged commit 409ba0f into lucidrains:main on Nov 7, 2024 (1 check passed).

pradeep-pyro deleted the pradeep/customizable_embedder_and_logit_mapper branch on November 7, 2024 at 22:12.