Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move logits.float() call #308

Merged
merged 1 commit into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,9 @@ def lce_forward(
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
Expand Down
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def lce_forward(

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if self.training and (labels is not None):
Expand All @@ -116,6 +115,8 @@ def lce_forward(
lce = LigerFusedLinearCrossEntropyLoss()
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
elif labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
Expand Down
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,11 @@ def lce_forward(
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = None
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
Expand Down
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,9 @@ def lce_forward(

else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
Expand Down
3 changes: 2 additions & 1 deletion src/liger_kernel/transformers/model/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ def lce_forward(
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
if labels is not None:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
Expand Down
Loading