diff --git a/olmo/model.py b/olmo/model.py index 82dbcf0dc..7f6e56aa1 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -736,11 +736,13 @@ def forward( # apply norm before if not self.config.norm_after: if self._activation_checkpoint_fn is not None: - qkv = self._activation_checkpoint_fn(self.attn_norm, x) + h = self._activation_checkpoint_fn(self.attn_norm, x) else: - qkv = self.attn_norm(x) + h = self.attn_norm(x) + else: + h = x - qkv = self.att_proj(qkv) + qkv = self.att_proj(h) if self.config.clip_qkv is not None: qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)