diff --git a/transformer_lens/HookedTransformerConfig.py b/transformer_lens/HookedTransformerConfig.py index 4458705de..02eaefe72 100644 --- a/transformer_lens/HookedTransformerConfig.py +++ b/transformer_lens/HookedTransformerConfig.py @@ -262,6 +262,7 @@ class HookedTransformerConfig: NTK_by_parts_low_freq_factor: float = 1.0 NTK_by_parts_high_freq_factor: float = 4.0 NTK_by_parts_factor: float = 8.0 + NTK_original_ctx_len: int = 8192 def __post_init__(self): if self.n_heads == -1: diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index 3b69bc738..87030b23d 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -504,7 +504,7 @@ def calculate_sin_cos_rotary( factor = self.cfg.NTK_by_parts_factor low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor - old_context_len = n_ctx + old_context_len = self.cfg.NTK_original_ctx_len low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py index 952d2bf9b..ae1d4cfb1 100644 --- a/transformer_lens/loading_from_pretrained.py +++ b/transformer_lens/loading_from_pretrained.py @@ -947,6 +947,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_low_freq_factor": 1.0, "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 32.0, + "NTK_original_ctx_len": 8192, } elif "Llama-3.2-3B" in official_model_name: cfg_dict = { @@ -971,6 +972,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_low_freq_factor": 1.0, "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 32.0, + "NTK_original_ctx_len": 8192, } elif "Llama-3.3-70B" in official_model_name: cfg_dict = { @@ -995,6 +997,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_low_freq_factor": 1.0, "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 8.0, + "NTK_original_ctx_len": 8192, } elif "Llama-3.1-8B" in official_model_name: cfg_dict = { @@ -1019,6 +1022,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_low_freq_factor": 1.0, "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 8.0, + "NTK_original_ctx_len": 8192, } elif "Llama-3.1-70B" in official_model_name: cfg_dict = { @@ -1043,6 +1047,7 @@ def convert_hf_model_config(model_name: str, **kwargs): "NTK_by_parts_low_freq_factor": 1.0, "NTK_by_parts_high_freq_factor": 4.0, "NTK_by_parts_factor": 8.0, + "NTK_original_ctx_len": 8192, } elif architecture == "GPTNeoForCausalLM": cfg_dict = {