From f75c7cd15253fb9561aa46bc5b38d685d34208b7 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Tue, 17 Dec 2024 12:24:39 +0000
Subject: [PATCH] reverting export

---
 .../modules/_export/_position_embeddings.py | 20 +++++++++-----------
 torchtune/modules/_export/attention.py      |  9 ++++-----
 torchtune/modules/_export/kv_cache.py       |  2 +-
 3 files changed, 14 insertions(+), 17 deletions(-)

diff --git a/torchtune/modules/_export/_position_embeddings.py b/torchtune/modules/_export/_position_embeddings.py
index bd4d14e516..0489b7f345 100644
--- a/torchtune/modules/_export/_position_embeddings.py
+++ b/torchtune/modules/_export/_position_embeddings.py
@@ -73,10 +73,9 @@ def _load_state_dict_hook(
             **kwargs (Dict[str, Any]): Additional keyword arguments.
 
         Raises:
-            ValueError:
-                If the shape of the loaded embedding is not compatible with the current embedding, **or**
-                if max_num_tiles_x, max_num_tiles_y are not equal, **or**
-                if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
+            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
+            ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
+            ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
         """
 
         embedding = state_dict.get(prefix + "embedding")
@@ -303,13 +302,12 @@ def _load_state_dict_hook(
             **kwargs (Dict[str, Any]): Additional keyword arguments.
 
         Raises:
-            ValueError:
-                If loaded local or global embedding n_tokens_per_tile is not derived
-                from a squared grid, **or**
-                if after interpolation, the shape of the loaded local embedding
-                is not compatible with the current embedding, **or**
-                if after interpolation, the shape of the loaded global embedding
-                is not compatible with the current embedding.
+            ValueError: if loaded local or global embedding n_tokens_per_tile is not derived
+                from a squared grid.
+            ValueError: if after interpolation, the shape of the loaded local embedding
+                is not compatible with the current embedding.
+            ValueError: if after interpolation, the shape of the loaded global embedding
+                is not compatible with the current embedding.
         """
 
         # process local_token_positional_embedding
diff --git a/torchtune/modules/_export/attention.py b/torchtune/modules/_export/attention.py
index 352f97d9c1..bb3fe4a94b 100644
--- a/torchtune/modules/_export/attention.py
+++ b/torchtune/modules/_export/attention.py
@@ -93,11 +93,10 @@ class MultiHeadAttention(nn.Module):
             Default value is 0.0.
 
     Raises:
-        ValueError:
-            If ``num_heads % num_kv_heads != 0``, **or**
-            If ``embed_dim % num_heads != 0``, **or**
-            If ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
-            if q_norm is defined without k_norm or vice versa
+        ValueError: If ``num_heads % num_kv_heads != 0``
+        ValueError: If ``embed_dim % num_heads != 0``
+        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
+        ValueError: if q_norm is defined without k_norm or vice versa
     """
 
     def __init__(
diff --git a/torchtune/modules/_export/kv_cache.py b/torchtune/modules/_export/kv_cache.py
index ad41de8859..8e0b7047e5 100644
--- a/torchtune/modules/_export/kv_cache.py
+++ b/torchtune/modules/_export/kv_cache.py
@@ -95,7 +95,7 @@ def update(
             Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
 
         Raises:
-            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length. #noqa
+            ValueError: if the batch size of the new key (or value) tensor is greater than the batch size used during cache setup.
         """
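
For reference, the four conditions restored in the attention.py docstring above correspond to constructor-time argument checks. The sketch below is a minimal, self-contained illustration of when each documented ValueError would fire; the function name and error messages are hypothetical, and it is not torchtune's implementation.

# Hypothetical sketch of the validation rules listed in the restored
# MultiHeadAttention "Raises" section; not torchtune's actual code.
def validate_attention_config(
    embed_dim: int,
    num_heads: int,
    num_kv_heads: int,
    attn_dropout: float = 0.0,
    q_norm=None,
    k_norm=None,
) -> None:
    # Grouped-query attention requires query heads to divide evenly across KV heads.
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    # Each head receives embed_dim // num_heads dimensions, so the split must be exact.
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    # Dropout is a probability.
    if attn_dropout < 0 or attn_dropout > 1:
        raise ValueError("attn_dropout must be between 0.0 and 1.0")
    # Query and key normalization only make sense as a pair.
    if (q_norm is None) != (k_norm is None):
        raise ValueError("q_norm and k_norm must be defined together")


if __name__ == "__main__":
    # 8 query heads cannot be grouped evenly over 3 KV heads, so this config is rejected.
    try:
        validate_attention_config(embed_dim=512, num_heads=8, num_kv_heads=3)
    except ValueError as err:
        print(f"rejected config: {err}")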