1 change: 1 addition & 0 deletions transformer_lens/HookedTransformerConfig.py

@@ -262,6 +262,7 @@ class HookedTransformerConfig:
     NTK_by_parts_low_freq_factor: float = 1.0
     NTK_by_parts_high_freq_factor: float = 4.0
     NTK_by_parts_factor: float = 8.0
+    NTK_original_ctx_len: int = 8192

     def __post_init__(self):
         if self.n_heads == -1:
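For reference, a minimal sketch of how the new field sits alongside the existing NTK-by-parts options when building a config by hand. The toy dimensions and the use_NTK_by_parts_rope flag are illustrative assumptions; only the NTK_* fields come from this diff:

```python
from transformer_lens import HookedTransformerConfig

# Toy dimensions for illustration; only the NTK_* fields mirror this diff.
cfg = HookedTransformerConfig(
    n_layers=2,
    d_model=128,
    n_ctx=131072,  # extended context the model is actually run at
    d_head=32,
    n_heads=4,
    act_fn="silu",
    positional_embedding_type="rotary",
    rotary_dim=32,
    use_NTK_by_parts_rope=True,  # assumed existing flag enabling the NTK-by-parts RoPE path
    NTK_by_parts_low_freq_factor=1.0,
    NTK_by_parts_high_freq_factor=4.0,
    NTK_by_parts_factor=8.0,
    NTK_original_ctx_len=8192,  # new field: pretraining context length, decoupled from n_ctx
)
```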
2 changes: 1 addition & 1 deletion transformer_lens/components/abstract_attention.py

@@ -504,7 +504,7 @@ def calculate_sin_cos_rotary(
             factor = self.cfg.NTK_by_parts_factor
             low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor
             high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor
-            old_context_len = n_ctx
+            old_context_len = self.cfg.NTK_original_ctx_len

             low_freq_wavelen = old_context_len / low_freq_factor
             high_freq_wavelen = old_context_len / high_freq_factor
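The point of the one-line change above: in NTK-by-parts ("Llama 3.1 style") RoPE scaling, old_context_len sets the wavelength thresholds that decide which rotary frequencies are rescaled, so it has to be the pretraining context length (now cfg.NTK_original_ctx_len) rather than whatever n_ctx the model happens to be run with. Below is a self-contained sketch of that scaling, following the published Llama 3.1 / Hugging Face formulation rather than TransformerLens's exact code; the function name is hypothetical:

```python
import torch

def ntk_by_parts_inv_freq(
    inv_freq: torch.Tensor,
    factor: float = 8.0,
    low_freq_factor: float = 1.0,
    high_freq_factor: float = 4.0,
    old_context_len: int = 8192,  # the value NTK_original_ctx_len now supplies
) -> torch.Tensor:
    """Rescale RoPE inverse frequencies, NTK-by-parts / Llama 3.1 style."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor
    wavelen = 2 * torch.pi / inv_freq

    # Low frequencies (wavelen > low_freq_wavelen): divide by the full factor.
    scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium frequencies: interpolate smoothly between scaled and unscaled.
    smooth = (old_context_len / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    interpolated = (1 - smooth) * inv_freq / factor + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)

    # High frequencies (short wavelengths) are left untouched by both branches.
    return torch.where(is_medium, interpolated, scaled)
```

Tying old_context_len to n_ctx meant these thresholds moved whenever n_ctx was extended, silently changing the rotary frequencies for pretrained Llama 3.x checkpoints; pinning it to the 8192-token pretraining context keeps them fixed.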
5 changes: 5 additions & 0 deletions transformer_lens/loading_from_pretrained.py

@@ -947,6 +947,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "NTK_by_parts_low_freq_factor": 1.0,
             "NTK_by_parts_high_freq_factor": 4.0,
             "NTK_by_parts_factor": 32.0,
+            "NTK_original_ctx_len": 8192,
         }
     elif "Llama-3.2-3B" in official_model_name:
         cfg_dict = {
@@ -971,6 +972,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "NTK_by_parts_low_freq_factor": 1.0,
             "NTK_by_parts_high_freq_factor": 4.0,
             "NTK_by_parts_factor": 32.0,
+            "NTK_original_ctx_len": 8192,
         }
     elif "Llama-3.3-70B" in official_model_name:
         cfg_dict = {
@@ -995,6 +997,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "NTK_by_parts_low_freq_factor": 1.0,
             "NTK_by_parts_high_freq_factor": 4.0,
             "NTK_by_parts_factor": 8.0,
+            "NTK_original_ctx_len": 8192,
         }
     elif "Llama-3.1-8B" in official_model_name:
         cfg_dict = {
@@ -1019,6 +1022,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "NTK_by_parts_low_freq_factor": 1.0,
             "NTK_by_parts_high_freq_factor": 4.0,
             "NTK_by_parts_factor": 8.0,
+            "NTK_original_ctx_len": 8192,
         }
     elif "Llama-3.1-70B" in official_model_name:
         cfg_dict = {
@@ -1043,6 +1047,7 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "NTK_by_parts_low_freq_factor": 1.0,
             "NTK_by_parts_high_freq_factor": 4.0,
             "NTK_by_parts_factor": 8.0,
+            "NTK_original_ctx_len": 8192,
         }
     elif architecture == "GPTNeoForCausalLM":
         cfg_dict = {
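With the per-model entries above, loading any of the listed Llama 3.x checkpoints picks the new field up automatically. A hypothetical spot-check (assumes access to the gated Llama weights on Hugging Face):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("meta-llama/Llama-3.1-8B")
print(model.cfg.NTK_original_ctx_len)  # 8192 with this change, independent of model.cfg.n_ctx
```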