Customizable Embedder and Logit Mapper #288
Conversation
Hi @lucidrains, I use x-transformers extensively and often run into the need to use custom embedder and logit mapper classes together with the `AutoregressiveWrapper` class, e.g. to implement pointer networks (https://arxiv.org/abs/1506.03134). Currently this is not possible, so I work with the lower-level `Decoder` class and implement the wrapper myself.

In this PR, I tried to enable support for this by slightly reorganizing the `TokenEmbedding` class and adding a class called `LogitMapper` to be used as the `TransformerWrapper.to_logits` layer. The idea is that we can implement custom classes conforming to these interfaces, and everything should work seamlessly. These classes accept kwargs in their `forward` method, which allows us to pass additional tensors to them.

The unit tests all pass, and I added some additional ones. I hope I didn't break anything! Let me know if you're OK to merge this change, or if you have a better way of achieving this behavior. Thank you.

@lucidrains: @pradeep-pyro this is cool, thank you! do you think it is possible to just keep the default as […]? also, perhaps we can just have […]

@pradeep-pyro: Thanks for the feedback! This solution is perfect, let me refactor. With this change, I think I can keep using […]

@pradeep-pyro: @lucidrains I modified the code like you suggested. It is certainly cleaner this way. Please take another look when you get the chance. Thanks!

@lucidrains: @pradeep-pyro looks great Pradeep! graciously accept! 🙏
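For readers landing here, below is a minimal sketch of the kind of customization this PR is after, assuming the interfaces described above: an embedder in the spirit of `TokenEmbedding` and a head used in place of `TransformerWrapper.to_logits`, both receiving extra tensors through `forward` kwargs. Everything else (the class names, keyword names such as `token_features` and `pointer_memory`, and the shapes) is a hypothetical illustration, not the code actually added in this PR.

```python
import torch
import torch.nn as nn


class FeatureMixEmbedding(nn.Module):
    # Hypothetical custom embedder: extra per-token tensors arrive via
    # **kwargs in forward and are mixed into the token embeddings.
    def __init__(self, dim, num_tokens, feature_dim):
        super().__init__()
        self.emb = nn.Embedding(num_tokens, dim)
        self.feature_proj = nn.Linear(feature_dim, dim)

    def forward(self, x, token_features=None, **kwargs):
        out = self.emb(x)
        if token_features is not None:
            # token_features: (batch, seq, feature_dim) side information
            out = out + self.feature_proj(token_features)
        return out


class PointerLogitMapper(nn.Module):
    # Hypothetical pointer-network-style head (https://arxiv.org/abs/1506.03134):
    # instead of projecting hidden states onto a fixed vocabulary, score each
    # candidate input position by a scaled dot product between the decoder
    # states and an encoded memory passed in through **kwargs.
    def __init__(self, dim):
        super().__init__()
        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)

    def forward(self, hiddens, pointer_memory=None, **kwargs):
        # hiddens:        (batch, seq, dim) decoder outputs
        # pointer_memory: (batch, num_candidates, dim) states to point at
        q = self.to_q(hiddens)
        k = self.to_k(pointer_memory)
        scale = q.shape[-1] ** -0.5
        # the "logits" are now scores over input positions, one row per step
        return torch.einsum('b n d, b m d -> b n m', q, k) * scale
```

The second class is the pointer-network use case from the description: the "vocabulary" at each decoding step is the set of input positions, so the logit mapper needs access to the encoded inputs at generation time, which is exactly what forwarding extra tensors through kwargs enables without reimplementing the autoregressive wrapper.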