Skip to content

Commit d2f3f15

Browse files
mntssbryce13950
andauthored
Fix LLama RoPE (#910)
Co-authored-by: Bryce Meyer <[email protected]>
1 parent 3212375 commit d2f3f15

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

transformer_lens/HookedTransformerConfig.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ class HookedTransformerConfig:
262262
NTK_by_parts_low_freq_factor: float = 1.0
263263
NTK_by_parts_high_freq_factor: float = 4.0
264264
NTK_by_parts_factor: float = 8.0
265+
NTK_original_ctx_len: int = 8192
265266

266267
def __post_init__(self):
267268
if self.n_heads == -1:

transformer_lens/components/abstract_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ def calculate_sin_cos_rotary(
504504
factor = self.cfg.NTK_by_parts_factor
505505
low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
506506
high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
507-
old_context_len = n_ctx
507+
old_context_len = self.cfg.NTK_original_ctx_len
508508

509509
low_freq_wavelen = old_context_len / low_freq_factor
510510
high_freq_wavelen = old_context_len / high_freq_factor

transformer_lens/loading_from_pretrained.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
947947
"NTK_by_parts_low_freq_factor": 1.0,
948948
"NTK_by_parts_high_freq_factor": 4.0,
949949
"NTK_by_parts_factor": 32.0,
950+
"NTK_original_ctx_len": 8192,
950951
}
951952
elif "Llama-3.2-3B" in official_model_name:
952953
cfg_dict = {
@@ -971,6 +972,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
971972
"NTK_by_parts_low_freq_factor": 1.0,
972973
"NTK_by_parts_high_freq_factor": 4.0,
973974
"NTK_by_parts_factor": 32.0,
975+
"NTK_original_ctx_len": 8192,
974976
}
975977
elif "Llama-3.3-70B" in official_model_name:
976978
cfg_dict = {
@@ -995,6 +997,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
995997
"NTK_by_parts_low_freq_factor": 1.0,
996998
"NTK_by_parts_high_freq_factor": 4.0,
997999
"NTK_by_parts_factor": 8.0,
1000+
"NTK_original_ctx_len": 8192,
9981001
}
9991002
elif "Llama-3.1-8B" in official_model_name:
10001003
cfg_dict = {
@@ -1019,6 +1022,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
10191022
"NTK_by_parts_low_freq_factor": 1.0,
10201023
"NTK_by_parts_high_freq_factor": 4.0,
10211024
"NTK_by_parts_factor": 8.0,
1025+
"NTK_original_ctx_len": 8192,
10221026
}
10231027
elif "Llama-3.1-70B" in official_model_name:
10241028
cfg_dict = {
@@ -1043,6 +1047,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
10431047
"NTK_by_parts_low_freq_factor": 1.0,
10441048
"NTK_by_parts_high_freq_factor": 4.0,
10451049
"NTK_by_parts_factor": 8.0,
1050+
"NTK_original_ctx_len": 8192,
10461051
}
10471052
elif architecture == "GPTNeoForCausalLM":
10481053
cfg_dict = {

0 commit comments

Comments
 (0)