Skip to content

Commit

Permalink
fix: πŸ› Fix the torch.median() on apple gpu (#94)
Browse files Browse the repository at this point in the history
βœ… Closes: #1
  • Loading branch information
Alndaly authored Jul 5, 2024
1 parent c5f1c0a commit ef55b8f
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion lib_layerdiffusion/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,16 @@ def estimate_augmented(self, pixel, latent):
result += [eps]

result = torch.stack(result, dim=0)
median = torch.median(result, dim=0).values
if self.load_device == torch.device("mps"):
'''
In case that apple silicon devices would crash when calling torch.median() on tensors
in gpu vram with dimensions higher than 4, we move it to cpu, call torch.median()
and then move the result back to gpu.
'''
median = torch.median(result.cpu(), dim=0).values
median = median.to(device=self.load_device, dtype=self.dtype)
else:
median = torch.median(result, dim=0).values
return median

@torch.no_grad()
Expand Down

0 comments on commit ef55b8f

Please sign in to comment.