diff --git a/lib_layerdiffusion/models.py b/lib_layerdiffusion/models.py index c9e6d3c..19396f7 100644 --- a/lib_layerdiffusion/models.py +++ b/lib_layerdiffusion/models.py @@ -299,7 +299,16 @@ def estimate_augmented(self, pixel, latent): result += [eps] result = torch.stack(result, dim=0) - median = torch.median(result, dim=0).values + if self.load_device == torch.device("mps"): + ''' + In case that apple silicon devices would crash when calling torch.median() on tensors + in gpu vram with dimensions higher than 4, we move it to cpu, call torch.median() + and then move the result back to gpu. + ''' + median = torch.median(result.cpu(), dim=0).values + median = median.to(device=self.load_device, dtype=self.dtype) + else: + median = torch.median(result, dim=0).values return median @torch.no_grad()