Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/diffusers/pipelines/stable_diffusion_xl/watermark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,15 @@ def apply_watermark(self, images: torch.FloatTensor):

images = (255 * (images / 2 + 0.5)).cpu().permute(0, 2, 3, 1).float().numpy()

# Convert RGB to BGR, which is the channel order expected by the watermark encoder.
images = images[:, :, :, ::-1]

images = [self.encoder.encode(image, "dwtDct") for image in images]

images = torch.from_numpy(np.array(images)).permute(0, 3, 1, 2)
# Convert BGR back to RGB
images = np.array(images)[:, :, :, ::-1]

images = torch.from_numpy(images).permute(0, 3, 1, 2)

images = torch.clamp(2 * (images / 255 - 0.5), min=-1.0, max=1.0)
return images