reverting export
SalmanMohammadi committed Dec 17, 2024
1 parent 0ea49e4 commit f75c7cd
Showing 3 changed files with 14 additions and 17 deletions.
20 changes: 9 additions & 11 deletions torchtune/modules/_export/_position_embeddings.py
@@ -73,10 +73,9 @@ def _load_state_dict_hook(
    **kwargs (Dict[str, Any]): Additional keyword arguments.

Raises:
-   ValueError:
-       If the shape of the loaded embedding is not compatible with the current embedding, **or**
-       if max_num_tiles_x, max_num_tiles_y are not equal, **or**
-       if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
+   ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
+   ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
+   ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
"""

embedding = state_dict.get(prefix + "embedding")
@@ -303,13 +302,12 @@ def _load_state_dict_hook(
    **kwargs (Dict[str, Any]): Additional keyword arguments.

Raises:
-   ValueError:
-       If loaded local or global embedding n_tokens_per_tile is not derived
-       from a squared grid, **or**
-       if after interpolation, the shape of the loaded local embedding
-       is not compatible with the current embedding, **or**
-       if after interpolation, the shape of the loaded global embedding
-       is not compatible with the current embedding.
+   ValueError: if loaded local or global embedding n_tokens_per_tile is not derived
+       from a squared grid.
+   ValueError: if after interpolation, the shape of the loaded local embedding
+       is not compatible with the current embedding.
+   ValueError: if after interpolation, the shape of the loaded global embedding
+       is not compatible with the current embedding.
"""

# process local_token_positional_embedding
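For context, the Raises entries restored above document shape checks performed inside the positional-embedding `_load_state_dict_hook`. The snippet below is a rough sketch of that kind of validation, not torchtune's actual implementation: the function name, the assumed `(max_num_tiles_x, max_num_tiles_y, n_tokens_per_tile, embed_dim)` layout, and the message wording are all illustrative assumptions.

```python
import torch


def _illustrative_tile_embedding_check(
    loaded: torch.Tensor, current: torch.Tensor
) -> None:
    """Sketch of the documented checks; not torchtune's code.

    Assumes a (max_num_tiles_x, max_num_tiles_y, n_tokens_per_tile, embed_dim) layout.
    """
    # The loaded embedding must be broadly compatible with the current one.
    if loaded.ndim != current.ndim or loaded.shape[-1] != current.shape[-1]:
        raise ValueError(
            f"Loaded embedding shape {tuple(loaded.shape)} is not compatible "
            f"with current embedding shape {tuple(current.shape)}."
        )

    # The tile grid is expected to be square.
    max_num_tiles_x, max_num_tiles_y = loaded.shape[0], loaded.shape[1]
    if max_num_tiles_x != max_num_tiles_y:
        raise ValueError(
            f"Expected max_num_tiles_x == max_num_tiles_y, "
            f"got {max_num_tiles_x} and {max_num_tiles_y}."
        )

    # After interpolating to the current grid size, shapes must match exactly.
    # (The interpolation itself is omitted here; only the final check is sketched.)
    if loaded.shape[0] != current.shape[0]:
        raise ValueError(
            f"After interpolation, loaded shape {tuple(loaded.shape)} does not "
            f"match current shape {tuple(current.shape)}."
        )
```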
9 changes: 4 additions & 5 deletions torchtune/modules/_export/attention.py
@@ -93,11 +93,10 @@ class MultiHeadAttention(nn.Module):
        Default value is 0.0.

Raises:
-   ValueError:
-       If ``num_heads % num_kv_heads != 0``, **or**
-       If ``embed_dim % num_heads != 0``, **or**
-       If ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
-       if q_norm is defined without k_norm or vice versa
+   ValueError: If ``num_heads % num_kv_heads != 0``
+   ValueError: If ``embed_dim % num_heads != 0``
+   ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
+   ValueError: if q_norm is defined without k_norm or vice versa
"""

def __init__(
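The MultiHeadAttention Raises entries above reflect constructor-time argument validation. Below is a minimal sketch of checks matching those four documented conditions, written as a free-standing helper for illustration rather than the module's actual `__init__`.

```python
from typing import Optional

import torch.nn as nn


def _illustrative_attention_arg_check(
    num_heads: int,
    num_kv_heads: int,
    embed_dim: int,
    attn_dropout: float = 0.0,
    q_norm: Optional[nn.Module] = None,
    k_norm: Optional[nn.Module] = None,
) -> None:
    """Sketch of the documented constraints; not torchtune's implementation."""
    # Query heads must be an exact multiple of key/value heads (GQA grouping).
    if num_heads % num_kv_heads != 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be divisible by num_kv_heads ({num_kv_heads})."
        )
    # The embedding dim must split evenly across query heads.
    if embed_dim % num_heads != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})."
        )
    # Dropout must be a valid probability.
    if attn_dropout < 0 or attn_dropout > 1:
        raise ValueError(f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0.")
    # q_norm and k_norm must be provided together, or not at all.
    if (q_norm is None) != (k_norm is None):
        raise ValueError("q_norm and k_norm must both be set or both be None.")
```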
2 changes: 1 addition & 1 deletion torchtune/modules/_export/kv_cache.py
@@ -95,7 +95,7 @@ def update(
    Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.

Raises:
-   AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. #noqa
+   AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
    ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
        used during cache setup.
"""
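The kv_cache.py hunk only drops a trailing `#noqa` from the AssertionError line; the documented behaviour is unchanged. For context, here is a rough sketch of what such guards inside an `update` method can look like, assuming caches and new values shaped `[batch, num_heads, seq_len, head_dim]`; this is an illustration, not torchtune's exact code.

```python
import torch


def _illustrative_cache_update_guards(
    k_cache: torch.Tensor, k_val: torch.Tensor, cache_pos: int
) -> None:
    """Sketch of the documented AssertionError / ValueError conditions."""
    bsz, _, seq_len, _ = k_val.shape
    max_bsz, _, max_seq_len, _ = k_cache.shape

    # AssertionError: the new keys must fit within the maximum cache sequence length.
    assert cache_pos + seq_len <= max_seq_len, (
        f"Cannot write {seq_len} positions starting at {cache_pos}; "
        f"the cache holds at most {max_seq_len}."
    )

    # ValueError: the incoming batch size cannot exceed the batch size used at setup.
    if bsz > max_bsz:
        raise ValueError(
            f"Batch size {bsz} exceeds cache batch size {max_bsz} set during setup."
        )
```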
