Skip to content

Commit

Permalink
Merge pull request #14546 from AUTOMATIC1111/fix-oft-dtype
Browse files Browse the repository at this point in the history
Fix dtype casting in OFT module
  • Loading branch information
AUTOMATIC1111 authored Jan 6, 2024
2 parents a4ee640 + f8f38c7 commit 8b6848c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions extensions-builtin/Lora/network_oft.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights):
self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)

def calc_updown(self, orig_weight):
oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
oft_blocks = self.oft_blocks.to(orig_weight.device)
eye = torch.eye(self.block_size, device=self.oft_blocks.device)

if self.is_kohya:
Expand All @@ -66,7 +66,7 @@ def calc_updown(self, orig_weight):
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
oft_blocks = torch.matmul(eye + block_Q, (eye - block_Q).float().inverse())

R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
R = oft_blocks.to(orig_weight.device)

# This errors out for MultiheadAttention, might need to be handled up-stream
merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
Expand All @@ -77,6 +77,6 @@ def calc_updown(self, orig_weight):
)
merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')

updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
updown = merged_weight.to(orig_weight.device) - orig_weight.to(merged_weight.dtype)
output_shape = orig_weight.shape
return self.finalize_updown(updown, orig_weight, output_shape)

0 comments on commit 8b6848c

Please sign in to comment.