madisonmay commented Mar 29, 2020

Still a work in progress, but the contextual embeddings line up with the PyTorch version, so this is roughly at parity with jax-bert.
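
For context, the parity check I mean is along these lines. This is a minimal sketch: the Haiku-side names (`bert_fn`, `params`) are hypothetical stand-ins, since this branch's entry points aren't shown here; only the PyTorch reference side uses real `transformers` APIs.

```python
import numpy as np
import torch
from transformers import BertModel, BertTokenizer

# Reference contextual embeddings from the PyTorch implementation.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

inputs = tokenizer("hello world", return_tensors="pt")
with torch.no_grad():
    reference = model(**inputs)[0].numpy()  # last hidden states

# Haiku side, with hypothetical entry points from this branch:
# jax_embeddings = bert_fn.apply(params, None, inputs["input_ids"].numpy())
# np.testing.assert_allclose(reference, jax_embeddings, atol=1e-4)
```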

TODO (mostly notes to myself):

  • Add save_pretrained
  • Make from_pretrained work with names
  • Add dropout at training time, pass a training flag through (see the sketch after this list)
  • Make sure weight initializations line up when pre-trained state isn't passed
  • Gradually work towards parity with the pytorch version if desired? (target models, BERT variants, etc.)
  • Write HaikuPretrainedModel to take advantage of archive resolution / make saving + loading compatible with pytorch bins?
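
For the dropout item, here's a sketch of what threading a training flag through a Haiku module could look like. The class and argument names are illustrative stand-ins, not the modules defined in this PR:

```python
import haiku as hk

class BertOutput(hk.Module):
    """Illustrative only: apply dropout when (and only when) is_training
    is True, drawing the mask from Haiku's per-call RNG."""

    def __init__(self, hidden_size: int, dropout_rate: float = 0.1, name=None):
        super().__init__(name=name)
        self.dense = hk.Linear(hidden_size)
        self.dropout_rate = dropout_rate

    def __call__(self, hidden_states, is_training: bool):
        hidden_states = self.dense(hidden_states)
        if is_training:
            # hk.dropout needs an explicit key; hk.next_rng_key() pulls one
            # from the rng passed to apply().
            hidden_states = hk.dropout(
                hk.next_rng_key(), self.dropout_rate, hidden_states
            )
        return hidden_states
```

One nice property of gating on the flag: at inference time no RNG key is consumed, so `apply()` stays usable with `rng=None`.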

To use the pre-trained weights cleanly, I ended up subclassing hk.Module. I'm not sure how I feel about this decision, but I couldn't think of a better approach at the time; feel free to suggest an alternative if you have ideas.
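
For discussion, one shape the subclassing approach could take is below. This is a minimal sketch under my own assumptions: the flat `pretrained_state` dict and the helper names are hypothetical, not the code in this PR.

```python
import haiku as hk
import jax.numpy as jnp

class PretrainedModule(hk.Module):
    """Hypothetical base class: hk.Module plus a helper that resolves
    parameters from a flat dict of pre-trained arrays when available,
    falling back to a normal initializer otherwise."""

    def __init__(self, pretrained_state=None, name=None):
        super().__init__(name=name)
        self._pretrained_state = pretrained_state or {}

    def pretrained_parameter(self, name, shape, fallback_init):
        def init(shape, dtype):
            if name in self._pretrained_state:
                return jnp.asarray(self._pretrained_state[name], dtype=dtype)
            return fallback_init(shape, dtype)
        return hk.get_parameter(name, shape, init=init)


# Example subclass: an embedding table seeded from the pre-trained state
# (the vocab/hidden sizes and the key name are illustrative).
class WordEmbeddings(PretrainedModule):
    def __call__(self, token_ids):
        table = self.pretrained_parameter(
            "word_embeddings",
            [30522, 768],
            hk.initializers.TruncatedNormal(stddev=0.02),
        )
        return table[token_ids]
```

An alternative that avoids subclassing entirely would be to build the params tree with `hk.transform(...).init` and then overwrite leaves with the pre-trained arrays before `apply()`; that might also help with the "make sure weight initializations line up" item above.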

mfuntowicz self-requested a review April 7, 2020 12:25

stale bot commented Jul 27, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label Jul 27, 2020
madisonmay closed this Jul 28, 2020