Conversation
```python
import torch

# Estimate the mean and covariance of the existing output embeddings.
pre_expansion_embeddings = model.language_model.lm_head.weight.data
mu = torch.mean(pre_expansion_embeddings, dim=0).float()
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(mu, covariance_matrix=1e-5 * sigma)

num_new_tokens = 1  # 1 for the `"<|image|>"` token
lm_head_weights = model.language_model.lm_head.weight

# Sample embeddings for the new tokens from the fitted distribution and append them to the lm_head.
new_token_embedding = torch.stack(tuple(dist.sample() for _ in range(num_new_tokens)), dim=0).to(device=lm_head_weights.device, dtype=lm_head_weights.dtype)
lm_head_weights.data = torch.cat([lm_head_weights.data, new_token_embedding], dim=0)
lm_head_weights.num_embeddings = lm_head_weights.data.shape[0]
```
This should already be done internally if you use the correct flag for `resize_token_embeddings`.
From what I see, we don't let users specify which embeddings to resize and use the input embeddings by default. In case the weights are tied (not the case for mllama) we also resize the output embeddings.
Or do you mean there is another method similar to `resize_token_embeddings`? I might have overlooked that.
We use the output of `get_input_embeddings`, which should always be the input embedding, and by default the output embedding from `get_output_embeddings` is resized when the weights are tied. But you are right, you can't resize only the lm head.
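The tied-versus-untied distinction discussed here can be sketched in plain Python (a hypothetical toy model, not the transformers implementation): resizing the input embedding only covers the output side when the two matrices are tied.

```python
import numpy as np

class ToyModel:
    """Minimal stand-in for a model with input embeddings and an lm_head."""

    def __init__(self, vocab_size, hidden_dim, tie_weights=False):
        self.input_embeddings = np.zeros((vocab_size, hidden_dim))
        # Tied: the lm_head shares the input embedding matrix; untied: its own copy.
        self.lm_head = self.input_embeddings if tie_weights else np.zeros((vocab_size, hidden_dim))
        self.tied = tie_weights

    def resize_token_embeddings(self, new_num_tokens):
        hidden = self.input_embeddings.shape[1]
        self.input_embeddings = np.zeros((new_num_tokens, hidden))
        if self.tied:
            # Tied weights: the output head follows the input embedding automatically.
            self.lm_head = self.input_embeddings
        # Untied (the mllama case): the lm_head keeps its old size and must be
        # resized explicitly, which is what the snippet above does by hand.

tied = ToyModel(100, 8, tie_weights=True)
tied.resize_token_embeddings(101)
print(tied.lm_head.shape)    # (101, 8)

untied = ToyModel(100, 8, tie_weights=False)
untied.resize_token_embeddings(101)
print(untied.lm_head.shape)  # still (100, 8) -- not resized
```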
Though there might be some util function you can re-use, no? 🤗 Feel free to merge!
Right, there was a way to hide all the ugly code in private methods.
* update docs
* be more explicit
* use available methods
What does this PR do?
Fixes #34304 and adds info about lm-head resizing. Maybe also fixes #33819?