diff --git a/litgpt/config.py b/litgpt/config.py
index 4f5205f832..cba52e3374 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -38,6 +38,7 @@ class Config:
     norm_class_name: Literal["LayerNorm", "RMSNorm"] = "LayerNorm"
     norm_eps: float = 1e-5
     norm_qk: bool = False
+    norm_qk_type: Literal["default", "olmo2"] = "default"
     post_attention_norm: bool = False
     post_mlp_norm: bool = False
     parallel_residual: bool = True
@@ -91,6 +92,8 @@ class Config:
     scale_embeddings: bool = False
     lm_head_bias: bool = False
     final_logit_softcapping: Optional[float] = None
+    norm_1: bool = True
+    norm_2: bool = True
     # The base period of the RoPE embeddings for local attention.
     # If not provided, rope_theta will be used for both local and global attention.
     rope_local_base_freq: Optional[float] = None
@@ -930,6 +933,68 @@ def norm_class(self) -> Type:
 configs.extend(olmo)


+olmo2 = [
+    # https://huggingface.co/allenai/OLMo-2-1124-7B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-7B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-7B{}"),
+        vocab_size=100278,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=4096,
+        n_layer=32,
+        n_head=32,
+        n_query_groups=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=11008,
+        rope_base=500000,
+        norm_qk=True,
+        post_mlp_norm=True,
+        norm_1=False,
+        norm_2=False,
+        norm_qk_type="olmo2",
+        post_attention_norm=True,
+    ),
+    # https://huggingface.co/allenai/OLMo-2-1124-13B/blob/main/config.json
+    dict(
+        name="OLMo-2-1124-13B{}",
+        hf_config=dict(org="allenai", name="OLMo-2-1124-13B{}"),
+        vocab_size=100278,
+        padded_vocab_size=100352,
+        block_size=4096,
+        n_embd=5120,
+        n_layer=40,
+        n_head=40,
+        n_query_groups=40,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        norm_eps=1e-06,
+        intermediate_size=13824,
+        rope_base=500000,
+        norm_qk=True,
+        post_mlp_norm=True,
+        norm_1=False,
+        norm_2=False,
+        norm_qk_type="olmo2",
+        post_attention_norm=True,
+    ),
+]
+
+for c in olmo2:
+    for kind in ("", "-SFT", "-DPO", "-Instruct"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+
 ###############
 # Google Gemma
 ###############
diff --git a/litgpt/model.py b/litgpt/model.py
index 24d952340e..f2c3c99f6b 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -271,12 +271,16 @@ def __init__(
                 " (non-parallel residual and shared attention norm)."
             )

-        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_1 = nn.Identity() if not config.norm_1 else config.norm_class(config.n_embd, eps=config.norm_eps)
         self.attn = CausalSelfAttention(config, block_idx)
         self.post_attention_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity()
         )
-        self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps)
+        self.norm_2 = (
+            nn.Identity()
+            if not config.norm_2
+            else (None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps))
+        )
         self.mlp = config.mlp_class(config)
         self.post_mlp_norm = (
             config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity()
@@ -325,6 +329,7 @@ def forward(
         else:
             x = attention_output + x
             x_normed = self.norm_2(x)
+
         return self.post_mlp_norm(self.mlp(x_normed)) + x


@@ -346,8 +351,12 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.apply_sliding_window_attention = config.sliding_window_indices[block_idx]

         if config.norm_qk:
-            self.norm_q = config.norm_class(config.head_size, eps=config.norm_eps)
-            self.norm_k = config.norm_class(config.head_size, eps=config.norm_eps)
+            norm_q_size = config.n_head * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            norm_k_size = (
+                config.n_query_groups * config.head_size if config.norm_qk_type == "olmo2" else config.head_size
+            )
+            self.norm_q = config.norm_class(norm_q_size, eps=config.norm_eps)
+            self.norm_k = config.norm_class(norm_k_size, eps=config.norm_eps)
         else:
             self.norm_q = self.norm_k = None

@@ -387,6 +396,10 @@ def forward(
         # Split qkv into query, key and value matrices.
         q, k, v = qkv.split((query_size, key_size, value_size), dim=-1)  # 3x(B, T, C*)

+        if self.config.norm_qk and self.config.norm_qk_type == "olmo2":
+            q = self.norm_q(q)
+            k = self.norm_k(k)
+
         # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the
         # embedding size (C) into num_heads (nh) and head_size (hs).
         q = q.view(B, T, n_head, head_size)  # (B, T, nh_q, hs)
@@ -400,7 +413,7 @@ def forward(
         k = k.transpose(1, 2)  # (B, nh_k, T, hs)
         v = v.transpose(1, 2)  # (B, nh_v, T, hs)

-        if self.config.norm_qk:
+        if self.config.norm_qk and self.config.norm_qk_type == "default":
             q = self.norm_q(q)
             k = self.norm_k(k)

diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index 3afe81c126..2c20f19c47 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -472,6 +472,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Llama3()
     if re.search("Llama-3.*-Instruct-*", model_name):
         return Llama3()
+    if re.search("OLMo-2.*-(Instruct|SFT|DPO)", model_name):
+        return Llama3()
     if re.search("R1", model_name):
         return R1Base()
     if re.search("FreeWilly2", model_name):
diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py
index 47ebc41edf..fd664cd040 100644
--- a/litgpt/scripts/convert_hf_checkpoint.py
+++ b/litgpt/scripts/convert_hf_checkpoint.py
@@ -533,6 +533,85 @@ def copy_weights_qwen_2_5(
             pbar.update(progress_per_file)

+
+def copy_weights_olmo2(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.post_attention_norm.weight",
+        "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.post_attention_norm.bias",
+        "model.layers.{}.post_feedforward_layernorm.weight": "transformer.h.{}.post_mlp_norm.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "model.norm.bias": "transformer.ln_f.bias",
+        "lm_head.weight": "lm_head.weight",
+    }
+    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
+        weight_map.update(
+            {
+                "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+                "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+                "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
+    if "lm_head.weight" not in state_dict:
+        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
+
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
 def copy_weights_qwen_3(
     config: Config,
     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
     state_dict: Dict[str, torch.Tensor],
@@ -693,6 +772,10 @@ def convert_hf_checkpoint(
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
         copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
+    elif model_name.lower().startswith("olmo-2-"):
+        # holder to reconstitute the split q, k, v
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_olmo2, config, qkv_weights)
     elif model_name.lower().startswith("qwen3"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
diff --git a/litgpt/scripts/convert_lit_checkpoint.py b/litgpt/scripts/convert_lit_checkpoint.py
index 232070b1fa..3718e6e671 100644
--- a/litgpt/scripts/convert_lit_checkpoint.py
+++ b/litgpt/scripts/convert_lit_checkpoint.py
@@ -393,6 +393,64 @@ def copy_weights_qwen_2_5(
         state_dict[to_name] = param

+
+def copy_weights_olmo2(
+    config: Config,
+    state_dict: Dict[str, torch.Tensor],
+    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    untie_weights: bool = False,
+    saver: Optional[incremental_save] = None,
+) -> None:
+    weight_map = {
+        "transformer.wte.weight": "model.embed_tokens.weight",
+        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
+        "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
+        "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
+        "transformer.h.{}.post_attention_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
+        "transformer.h.{}.post_attention_norm.bias": "model.layers.{}.post_attention_layernorm.bias",
+        "transformer.h.{}.post_mlp_norm.weight": "model.layers.{}.post_feedforward_layernorm.weight",
+        "transformer.ln_f.weight": "model.norm.weight",
+        "transformer.ln_f.bias": "model.norm.bias",
+        "lm_head.weight": "lm_head.weight",
+    }
+    if config.mlp_class_name in ("LLaMAMLP", "GemmaMLP"):
+        weight_map.update(
+            {
+                "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
+                "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
+                "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
+            }
+        )
+    else:
+        raise NotImplementedError
+
+    for from_name, param in lit_weights.items():
+        if from_name == "lm_head.weight" and untie_weights:
+            continue
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        param = load_param(param, from_name, None)
+        if from_name.endswith(".attn.qkv.weight"):
+            to_names = (
+                "model.layers.{}.self_attn.q_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.k_proj.weight".format(*ids),
+                "model.layers.{}.self_attn.v_proj.weight".format(*ids),
+            )
+            params = param.split(
+                (
+                    config.n_head * config.head_size,
+                    config.n_query_groups * config.head_size,
+                    config.n_query_groups * config.head_size,
+                )
+            )
+        else:
+            to_names = (weight_map[name_template].format(*ids),)
+            params = (param,)
+
+        for to_name, param in zip(to_names, params):
+            if saver is not None:
+                param = saver.store_early(param)
+            state_dict[to_name] = param
+
+
 def copy_weights_qwen_3(
     config: Config,
     state_dict: Dict[str, torch.Tensor],
@@ -487,6 +545,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
         copy_fn = partial(copy_weights_phi, config)
     elif config.name.lower().startswith(("qwen2.5", "qwq")):
         copy_fn = partial(copy_weights_qwen_2_5, config)
+    elif config.name.lower().startswith("olmo-2-"):
+        copy_fn = partial(copy_weights_olmo2, config)
     elif config.name.lower().startswith("qwen3"):
         copy_fn = partial(copy_weights_qwen_3, config)
     elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
diff --git a/tests/test_model.py b/tests/test_model.py
index 9f55f5d373..ab5dec50b2 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -29,6 +29,7 @@
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
 from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
 from transformers.models.olmo import OlmoConfig, OlmoForCausalLM
+from transformers.models.olmo2 import Olmo2Config, Olmo2ForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 from transformers.models.qwen3 import Qwen3Config, Qwen3ForCausalLM

@@ -41,6 +42,7 @@
     copy_weights_gemma_3,
     copy_weights_gpt_neox,
     copy_weights_hf_llama,
+    copy_weights_olmo2,
     copy_weights_phi,
     copy_weights_qwen_2_5,
     copy_weights_qwen_3,
@@ -634,6 +636,66 @@ def test_against_olmo(model_name, device, dtype):
     torch.testing.assert_close(ours_y, theirs_y)

+
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("OLMo-2-1124-7B", "OLMo-2-1124-13B"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                _RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_olmo2(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = Olmo2Config(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        intermediate_size=ours_config.intermediate_size,
+        num_hidden_layers=ours_config.n_layer,
+        num_attention_heads=ours_config.n_head,
+        num_key_value_heads=ours_config.n_query_groups,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        attention_bias=ours_config.bias,
+        rope_theta=ours_config.rope_base,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = Olmo2ForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_olmo2(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
+
+
 @torch.inference_mode()
 @pytest.mark.parametrize(
     ("device", "dtype"),
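
# --------------------------------------------------------------------------------------
# Reviewer note, not part of the patch: a minimal, self-contained sketch of the block
# wiring that the new config flags encode (norm_1=False, norm_2=False,
# post_attention_norm=True, post_mlp_norm=True, norm_qk_type="olmo2"). This is NOT
# litgpt's module code; the class name and the simplified non-gated MLP are invented for
# illustration, RoPE and grouped-query attention are omitted, and it assumes
# PyTorch >= 2.4 for nn.RMSNorm. It shows the two OLMo-2-specific choices: QK-norm over
# the full q/k projections *before* the head reshape, and normalization applied *after*
# each sub-block instead of the usual pre-norm.
import torch
import torch.nn as nn
import torch.nn.functional as F


class Olmo2BlockSketch(nn.Module):
    def __init__(self, n_embd: int, n_head: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.n_head, self.head_size = n_head, n_embd // n_head
        self.qkv = nn.Linear(n_embd, 3 * n_embd, bias=False)
        self.proj = nn.Linear(n_embd, n_embd, bias=False)
        # "olmo2" QK-norm: sized n_head * head_size (= n_embd here) and applied to the
        # full projection, unlike the "default" variant that normalizes each
        # head_size-wide slice after the (B, T, nh, hs) reshape.
        self.norm_q = nn.RMSNorm(n_embd, eps=eps)
        self.norm_k = nn.RMSNorm(n_embd, eps=eps)
        # Simplified stand-in for the gated LLaMAMLP used by the real config.
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd, bias=False), nn.SiLU(), nn.Linear(4 * n_embd, n_embd, bias=False)
        )
        # norm_1/norm_2 are nn.Identity() in the patch; the weights live in the post-norms.
        self.post_attention_norm = nn.RMSNorm(n_embd, eps=eps)
        self.post_mlp_norm = nn.RMSNorm(n_embd, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        q, k, v = self.qkv(x).split(C, dim=-1)
        q, k = self.norm_q(q), self.norm_k(k)  # full-width QK-norm, before the head split
        q, k, v = (t.view(B, T, self.n_head, self.head_size).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        y = y.transpose(1, 2).reshape(B, T, C)
        x = x + self.post_attention_norm(self.proj(y))  # post-norm residual, no pre-norm
        return x + self.post_mlp_norm(self.mlp(x))


if __name__ == "__main__":
    block = Olmo2BlockSketch(n_embd=32, n_head=8)
    print(block(torch.randn(1, 5, 32)).shape)  # torch.Size([1, 5, 32])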