diff --git a/README.md b/README.md
index ff6f713..1877875 100644
--- a/README.md
+++ b/README.md
@@ -1,219 +1,179 @@
-# Stable Diffusion with Aesthetic Gradients (WIP)
+# Stable Diffusion with Aesthetic Gradients
-Code will be uploaded here over the next days.
+This is the codebase for the article [Personalizing Text-to-Image Generation via Aesthetic Gradients](https://arxiv.org/abs/2209.12330):
+> This work proposes aesthetic gradients, a method to personalize a CLIP-conditioned diffusion model by guiding the generative process towards custom aesthetics defined by the user from a set of images. The approach is validated with qualitative and quantitative experiments, using the recent stable diffusion model and several aesthetically-filtered datasets.
-*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
+In particular, this reposiory allows the user to use the aesthetic gradients technique described in the previous paper to personalize stable diffusion.
-[**High-Resolution Image Synthesis with Latent Diffusion Models**](https://ommer-lab.com/research/latent-diffusion-models/)
-[Robin Rombach](https://github.com/rromb)\*,
-[Andreas Blattmann](https://github.com/ablattmann)\*,
-[Dominik Lorenz](https://github.com/qp-qp)\,
-[Patrick Esser](https://github.com/pesser),
-[Björn Ommer](https://hci.iwr.uni-heidelberg.de/Staff/bommer)
-_[CVPR '22 Oral](https://openaccess.thecvf.com/content/CVPR2022/html/Rombach_High-Resolution_Image_Synthesis_With_Latent_Diffusion_Models_CVPR_2022_paper.html) |
-[GitHub](https://github.com/CompVis/latent-diffusion) | [arXiv](https://arxiv.org/abs/2112.10752) | [Project page](https://ommer-lab.com/research/latent-diffusion-models/)_
-![txt2img-stable2](assets/stable-samples/txt2img/merged-0006.png)
-[Stable Diffusion](#stable-diffusion-v1) is a latent text-to-image diffusion
-model.
-Thanks to a generous compute donation from [Stability AI](https://stability.ai/) and support from [LAION](https://laion.ai/), we were able to train a Latent Diffusion Model on 512x512 images from a subset of the [LAION-5B](https://laion.ai/blog/laion-5b/) database.
-Similar to Google's [Imagen](https://arxiv.org/abs/2205.11487),
-this model uses a frozen CLIP ViT-L/14 text encoder to condition the model on text prompts.
-With its 860M UNet and 123M text encoder, the model is relatively lightweight and runs on a GPU with at least 10GB VRAM.
-See [this section](#stable-diffusion-v1) below and the [model card](https://huggingface.co/CompVis/stable-diffusion).
+## Prerequisites
-
-## Requirements
-A suitable [conda](https://conda.io/) environment named `ldm` can be created
-and activated with:
+This is a fork of the original stable-diffusion repository, so the prerequisites are the same as the [original repository](https://github.com/CompVis/stable-diffusion/). In particular, when cloning this repo, install the library as:
+```bash
+pip install -e .
```
-conda env create -f environment.yaml
-conda activate ldm
+
+## Usage
+
+You can use the same arguments as with the original stable diffusion repository. The script `scripts/txt2img.py` has the additional arguments:
+
+- `--aesthetic_steps`: number of optimization steps when doing the personalization. For a given prompt, it is recommended to start with few steps (2 or 3), and then gradually increase it (trying 5, 10, 15, 20, etc). The greater the value, the more the resulting image will be biased towards the aesthetic embedding.
+- `--aesthetic_lr`: learning rate for the aesthetic gradient optimization. The default value is 0.0001. This value almost usually works well enough, so you can just only tune the previous argument.
+- `--aesthetic_embedding`: path to the stored pytorch tensor (.pt format) containing the aesthetic embedding. It must be of shape 1x768 (CLIP-L/14 size). See below for computing your own aesthetic embeddings.
+
+In this repository we include all the aesthetic embeddings used in the paper. All of them are in the directory `aesthetic_embeddings`:
+* sac_8plus.pt
+* laion_7plus.pt
+* aivazovsky.pt
+* cloudcore.pt
+* gloomcore.pt
+* glowwave.pt
+
+See the paper to see how they were obtained.
+
+### Examples
+
+Let's see some examples now. This would be with the un-personalized, original SD model:
+
+```bash
+python scripts/txt2img.py --prompt "Roman city on top of a ridge, sci-fi illustration by Greg Rutkowski #sci-fi detailed vivid colors gothic concept illustration by James Gurney and Zdzislaw Beksiński vivid vivid colorsg concept illustration colorful interior" --seed 332 --plms --aesthetic_steps 0 --W 768 --aesthetic_embedding aesthetic_embeddings/laion_7plus.pt
```
-You can also update an existing [latent diffusion](https://github.com/CompVis/latent-diffusion) environment by running
+![sample](assets/grid-0131.png)
+
+If we now personalize it with the LAION embedding, note how the images get more floral patterns, as this is one common pattern of the LAION aesthetics dataset:
+```bash
+python scripts/txt2img.py --prompt "Roman city on top of a ridge, sci-fi illustration by Greg Rutkowski #sci-fi detailed vivid colors gothic concept illustration by James Gurney and Zdzislaw Beksiński vivid vivid colorsg concept illustration colorful interior" --seed 332 --plms --aesthetic_steps 5 --W 768 --aesthetic_embedding aesthetic_embeddings/laion_7plus.pt
```
-conda install pytorch torchvision -c pytorch
-pip install transformers==4.19.2 diffusers invisible-watermark
-pip install -e .
-```
+![sample](assets/grid-0133.png)
-## Stable Diffusion v1
+Increasing the number of steps more...
-Stable Diffusion v1 refers to a specific configuration of the model
-architecture that uses a downsampling-factor 8 autoencoder with an 860M UNet
-and CLIP ViT-L/14 text encoder for the diffusion model. The model was pretrained on 256x256 images and
-then finetuned on 512x512 images.
+```bash
+python scripts/txt2img.py --prompt "Roman city on top of a ridge, sci-fi illustration by Greg Rutkowski #sci-fi detailed vivid colors gothic concept illustration by James Gurney and Zdzislaw Beksiński vivid vivid colorsg concept illustration colorful interior" --seed 332 --plms --aesthetic_steps 8 --W 768 --aesthetic_embedding aesthetic_embeddings/laion_7plus.pt
+```
-*Note: Stable Diffusion v1 is a general text-to-image diffusion model and therefore mirrors biases and (mis-)conceptions that are present
-in its training data.
-Details on the training procedure and data, as well as the intended use of the model can be found in the corresponding [model card](Stable_Diffusion_v1_Model_Card.md).*
+![sample](assets/grid-0135.png)
-The weights are available via [the CompVis organization at Hugging Face](https://huggingface.co/CompVis) under [a license which contains specific use-based restrictions to prevent misuse and harm as informed by the model card, but otherwise remains permissive](LICENSE). While commercial use is permitted under the terms of the license, **we do not recommend using the provided weights for services or products without additional safety mechanisms and considerations**, since there are [known limitations and biases](Stable_Diffusion_v1_Model_Card.md#limitations-and-bias) of the weights, and research on safe and ethical deployment of general text-to-image models is an ongoing effort. **The weights are research artifacts and should be treated as such.**
+Let's see another example:
-[The CreativeML OpenRAIL M license](LICENSE) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
+```bash
+python scripts/txt2img.py --prompt "A portal towards other dimension" --plms --seed 332 --aesthetic_steps 15 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt
+```
+![sample](assets/grid-0073.png)
-### Weights
+If we increase it to 20 steps, we get a more pronounced effect:
-We currently provide the following checkpoints:
+```bash
+python scripts/txt2img.py --prompt "A portal towards other dimension" --plms --seed 332 --aesthetic_steps 20 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt
+```
-- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
- 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
-- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
- 515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally
-filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata, the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
-- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
-- `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
+![sample](assets/grid-0072.png)
-Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
-5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
-steps show the relative improvements of the checkpoints:
-![sd evaluation results](assets/v1-variants-scores.jpg)
+We can set the steps to 0 to get the outputs for the original stable diffusion model:
+```bash
+python scripts/txt2img.py --prompt "A portal towards other dimension" --plms --seed 332 --aesthetic_steps 0 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt
+```
+![sample](assets/grid-0075.png)
-### Text-to-Image with Stable Diffusion
-![txt2img-stable2](assets/stable-samples/txt2img/merged-0005.png)
-![txt2img-stable2](assets/stable-samples/txt2img/merged-0007.png)
+Note that since we have used the SAC dataset for the personalization, the optimized results are more biased towards fantasy aesthetics.
-Stable Diffusion is a latent diffusion model conditioned on the (non-pooled) text embeddings of a CLIP ViT-L/14 text encoder.
-We provide a [reference script for sampling](#reference-sampling-script), but
-there also exists a [diffusers integration](#diffusers-integration), which we
-expect to see more active community development.
+Now we turn to another example.
-#### Reference Sampling Script
+To see more examples, have a look at https://arxiv.org/abs/2209.12330
-We provide a reference sampling script, which incorporates
+## Using your own embeddings
-- a [Safety Checker Module](https://github.com/CompVis/stable-diffusion/pull/36),
- to reduce the probability of explicit outputs,
-- an [invisible watermarking](https://github.com/ShieldMnt/invisible-watermark)
- of the outputs, to help viewers [identify the images as machine-generated](scripts/tests/test_watermark.py).
+If you want to use your own aesthetic embeddings from a set of images, you can use the script `scripts/gen_aesthetic_embedding.py`. This script takes as input a directory containing images, and outputs a pytorch tensor containing the aesthetic embedding, so you can use it as in the previous commands.
-After [obtaining the `stable-diffusion-v1-*-original` weights](#weights), link them
-```
-mkdir -p models/ldm/stable-diffusion-v1/
-ln -s models/ldm/stable-diffusion-v1/model.ckpt
-```
-and sample with
-```
-python scripts/txt2img.py --prompt "a photograph of an astronaut riding a horse" --plms
-```
+Some examples with three works from the painter Aivazovsky: [reference_images/aivazovsky](reference_images/aivazovsky)
-By default, this uses a guidance scale of `--scale 7.5`, [Katherine Crowson's implementation](https://github.com/CompVis/latent-diffusion/pull/51) of the [PLMS](https://arxiv.org/abs/2202.09778) sampler,
-and renders images of size 512x512 (which it was trained on) in 50 steps. All supported arguments are listed below (type `python scripts/txt2img.py --help`).
+```bash
+python scripts/txt2img.py --prompt "a painting of a tree, oil on canvas" --plms --seed 332 --aesthetic_steps 50 --aesthetic_embedding aesthetic_embeddings/aivazovsky.pt
+```
+![sample](assets/grid-0089.png)
-```commandline
-usage: txt2img.py [-h] [--prompt [PROMPT]] [--outdir [OUTDIR]] [--skip_grid] [--skip_save] [--ddim_steps DDIM_STEPS] [--plms] [--laion400m] [--fixed_code] [--ddim_eta DDIM_ETA]
- [--n_iter N_ITER] [--H H] [--W W] [--C C] [--f F] [--n_samples N_SAMPLES] [--n_rows N_ROWS] [--scale SCALE] [--from-file FROM_FILE] [--config CONFIG] [--ckpt CKPT]
- [--seed SEED] [--precision {full,autocast}]
+Note that just adding the modifier "by Aivazoysky" to the prompt does not work so well:
-optional arguments:
- -h, --help show this help message and exit
- --prompt [PROMPT] the prompt to render
- --outdir [OUTDIR] dir to write results to
- --skip_grid do not save a grid, only individual samples. Helpful when evaluating lots of samples
- --skip_save do not save individual samples. For speed measurements.
- --ddim_steps DDIM_STEPS
- number of ddim sampling steps
- --plms use plms sampling
- --laion400m uses the LAION400M model
- --fixed_code if enabled, uses the same starting code across samples
- --ddim_eta DDIM_ETA ddim eta (eta=0.0 corresponds to deterministic sampling
- --n_iter N_ITER sample this often
- --H H image height, in pixel space
- --W W image width, in pixel space
- --C C latent channels
- --f F downsampling factor
- --n_samples N_SAMPLES
- how many samples to produce for each given prompt. A.k.a. batch size
- --n_rows N_ROWS rows in the grid (default: n_samples)
- --scale SCALE unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))
- --from-file FROM_FILE
- if specified, load prompts from this file
- --config CONFIG path to config which constructs model
- --ckpt CKPT path to checkpoint of model
- --seed SEED the seed (for reproducible sampling)
- --precision {full,autocast}
- evaluate at this precision
+```bash
+python scripts/txt2img.py --prompt "a painting of a tree, oil on canvas by Aivazovsky" --plms --seed 332 --aesthetic_steps 0 --aesthetic_embedding aesthetic_embeddings/aivazovsky.pt
```
-Note: The inference config for all v1 versions is designed to be used with EMA-only checkpoints.
-For this reason `use_ema=False` is set in the configuration, otherwise the code will try to switch from
-non-EMA to EMA weights. If you want to examine the effect of EMA vs no EMA, we provide "full" checkpoints
-which contain both types of weights. For these, `use_ema=False` will load and use the non-EMA weights.
+![sample](assets/grid-0091.png)
-#### Diffusers Integration
+Another example, mixing the styles of two painters (one in the prompt, the other as the aesthetic embedding):
-A simple way to download and sample Stable Diffusion is by using the [diffusers library](https://github.com/huggingface/diffusers/tree/main#new--stable-diffusion-is-now-fully-compatible-with-diffusers):
-```py
-# make sure you're logged in with `huggingface-cli login`
-from torch import autocast
-from diffusers import StableDiffusionPipeline
+```bash
+96 python scripts/txt2img.py --prompt "a gothic cathedral in a stunning landscape by Jean-Honoré Fragonard" --plms --seed 139782398 --aesthetic_steps 12 --aesthetic_embedding aesthetic_embeddings/aivazovsky.pt
+```
+![sample](assets/grid-0096.png)
-pipe = StableDiffusionPipeline.from_pretrained(
- "CompVis/stable-diffusion-v1-4",
- use_auth_token=True
-).to("cuda")
+Whereas the original SD would output this:
-prompt = "a photo of an astronaut riding a horse on mars"
-with autocast("cuda"):
- image = pipe(prompt)["sample"][0]
-
-image.save("astronaut_rides_horse.png")
+```bash
+python scripts/txt2img.py --prompt "a gothic cathedral in a stunning landscape by Jean-Honoré Fragonard" --plms --seed 139782398 --aesthetic_steps 0 --aesthetic_embedding aesthetic_embeddings/aivazovsky.pt
```
+![sample](assets/grid-0097.png)
-### Image Modification with Stable Diffusion
+## Using it with other fine-tuned SD models
-By using a diffusion-denoising mechanism as first proposed by [SDEdit](https://arxiv.org/abs/2108.01073), the model can be used for different
-tasks such as text-guided image-to-image translation and upscaling. Similar to the txt2img sampling script,
-we provide a script to perform image modification with Stable Diffusion.
+The aesthetic gradients technique can be used with any fine-tuned SD model. For example, you can use it with the [Pokemon finetune](https://replicate.com/lambdal/text-to-pokemon):
-The following describes an example where a rough sketch made in [Pinta](https://www.pinta-project.com/) is converted into a detailed artwork.
+```bash
+python scripts/txt2img.py --prompt "robotic cat with wings" --plms --seed 7 --ckpt ../stable-diffusion/ema-only-epoch\=000142.ckpt --aesthetic_steps 15 --aesthetic_embedding aesthetic_embeddings/laion_7plus.pt
```
-python scripts/img2img.py --prompt "A fantasy landscape, trending on artstation" --init-img --strength 0.8
+
+![sample](assets/grid-0033.png)
+
+The previous prompt was personalized with the LAION aesthetics embedding, so it has more childish-like than using just the original model:
+
+```bash
+python scripts/txt2img.py --prompt "robotic cat with wings" --plms --seed 7 --ckpt ../stable-diffusion/ema-only-epoch\=000142.ckpt --aesthetic_steps 0 --aesthetic_embedding aesthetic_embeddings/laion_7plus.pt
```
-Here, strength is a value between 0.0 and 1.0, that controls the amount of noise that is added to the input image.
-Values that approach 1.0 allow for lots of variations but will also produce images that are not semantically consistent with the input. See the following example.
+![sample](assets/grid-0035.png)
-**Input**
+Another example:
-![sketch-in](assets/stable-samples/img2img/sketch-mountains-input.jpg)
+```bash
-**Outputs**
+python scripts/txt2img.py --prompt "Dragonite" --plms --seed 7 --ckpt ../stable-diffusion/ema-only-epoch\=000142.ckpt --aesthetic_steps 10 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt
+```
+
+![sample](assets/grid-0047.png)
-![out3](assets/stable-samples/img2img/mountains-3.png)
-![out2](assets/stable-samples/img2img/mountains-2.png)
+```bash
+
+python scripts/txt2img.py --prompt "Dragonite" --plms --seed 7 --ckpt ../stable-diffusion/ema-only-epoch\=000142.ckpt --aesthetic_steps 0 --aesthetic_embedding aesthetic_embeddings/sac_8plus.pt
+```
-This procedure can, for example, also be used to upscale samples from the base model.
+![sample](assets/grid-0043.png)
-## Comments
-- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
-and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
-Thanks for open-sourcing!
-- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
+## Citation
-## BibTeX
+If you find this is useful for your research, please cite our paper:
```
-@misc{rombach2021highresolution,
- title={High-Resolution Image Synthesis with Latent Diffusion Models},
- author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
- year={2021},
- eprint={2112.10752},
- archivePrefix={arXiv},
- primaryClass={cs.CV}
+@article{gallego2022personalizing,
+ title={Personalizing Text-to-Image Generation via Aesthetic Gradients},
+ author={Gallego, Victor},
+ journal={arXiv preprint arXiv:2209.12330},
+ year={2022}
}
```
+
+
+
diff --git a/aesthetic_embeddings/aivazovsky.pt b/aesthetic_embeddings/aivazovsky.pt
new file mode 100644
index 0000000..745724b
Binary files /dev/null and b/aesthetic_embeddings/aivazovsky.pt differ
diff --git a/aesthetic_embeddings/cloudcore.pt b/aesthetic_embeddings/cloudcore.pt
new file mode 100644
index 0000000..dac999e
Binary files /dev/null and b/aesthetic_embeddings/cloudcore.pt differ
diff --git a/aesthetic_embeddings/gloomcore.pt b/aesthetic_embeddings/gloomcore.pt
new file mode 100644
index 0000000..661959f
Binary files /dev/null and b/aesthetic_embeddings/gloomcore.pt differ
diff --git a/aesthetic_embeddings/glowwave.pt b/aesthetic_embeddings/glowwave.pt
new file mode 100644
index 0000000..8955fcc
Binary files /dev/null and b/aesthetic_embeddings/glowwave.pt differ
diff --git a/aesthetic_embeddings/laion_7plus.pt b/aesthetic_embeddings/laion_7plus.pt
new file mode 100644
index 0000000..d19c2b0
Binary files /dev/null and b/aesthetic_embeddings/laion_7plus.pt differ
diff --git a/aesthetic_embeddings/sac_8plus.pt b/aesthetic_embeddings/sac_8plus.pt
new file mode 100644
index 0000000..de81dc5
Binary files /dev/null and b/aesthetic_embeddings/sac_8plus.pt differ
diff --git a/assets/grid-0033.png b/assets/grid-0033.png
new file mode 100644
index 0000000..64acf87
Binary files /dev/null and b/assets/grid-0033.png differ
diff --git a/assets/grid-0035.png b/assets/grid-0035.png
new file mode 100644
index 0000000..111ba9b
Binary files /dev/null and b/assets/grid-0035.png differ
diff --git a/assets/grid-0043.png b/assets/grid-0043.png
new file mode 100644
index 0000000..20a8eac
Binary files /dev/null and b/assets/grid-0043.png differ
diff --git a/assets/grid-0047.png b/assets/grid-0047.png
new file mode 100644
index 0000000..9626c1c
Binary files /dev/null and b/assets/grid-0047.png differ
diff --git a/assets/grid-0072.png b/assets/grid-0072.png
new file mode 100644
index 0000000..f8cd45a
Binary files /dev/null and b/assets/grid-0072.png differ
diff --git a/assets/grid-0073.png b/assets/grid-0073.png
new file mode 100644
index 0000000..2894d60
Binary files /dev/null and b/assets/grid-0073.png differ
diff --git a/assets/grid-0075.png b/assets/grid-0075.png
new file mode 100644
index 0000000..ad46db0
Binary files /dev/null and b/assets/grid-0075.png differ
diff --git a/assets/grid-0089.png b/assets/grid-0089.png
new file mode 100644
index 0000000..128a766
Binary files /dev/null and b/assets/grid-0089.png differ
diff --git a/assets/grid-0091.png b/assets/grid-0091.png
new file mode 100644
index 0000000..db2cc77
Binary files /dev/null and b/assets/grid-0091.png differ
diff --git a/assets/grid-0096.png b/assets/grid-0096.png
new file mode 100644
index 0000000..6d739f6
Binary files /dev/null and b/assets/grid-0096.png differ
diff --git a/assets/grid-0097.png b/assets/grid-0097.png
new file mode 100644
index 0000000..63bcfc3
Binary files /dev/null and b/assets/grid-0097.png differ
diff --git a/assets/grid-0131.png b/assets/grid-0131.png
new file mode 100644
index 0000000..979b8cd
Binary files /dev/null and b/assets/grid-0131.png differ
diff --git a/assets/grid-0133.png b/assets/grid-0133.png
new file mode 100644
index 0000000..fb9264e
Binary files /dev/null and b/assets/grid-0133.png differ
diff --git a/assets/grid-0135.png b/assets/grid-0135.png
new file mode 100644
index 0000000..3f4acee
Binary files /dev/null and b/assets/grid-0135.png differ
diff --git a/configs/stable-diffusion/v1-inference-aesthetic.yaml b/configs/stable-diffusion/v1-inference-aesthetic.yaml
new file mode 100644
index 0000000..036dc23
--- /dev/null
+++ b/configs/stable-diffusion/v1-inference-aesthetic.yaml
@@ -0,0 +1,74 @@
+model:
+ base_learning_rate: 1.0e-04
+ target: ldm.models.diffusion.ddpm.LatentDiffusion
+ params:
+ linear_start: 0.00085
+ linear_end: 0.0120
+ num_timesteps_cond: 1
+ log_every_t: 200
+ timesteps: 1000
+ first_stage_key: "jpg"
+ cond_stage_key: "txt"
+ image_size: 64
+ channels: 4
+ cond_stage_trainable: false # Note: different from the one we trained before
+ conditioning_key: crossattn
+ monitor: val/loss_simple_ema
+ scale_factor: 0.18215
+ use_ema: False
+
+ scheduler_config: # 10000 warmup steps
+ target: ldm.lr_scheduler.LambdaLinearScheduler
+ params:
+ warm_up_steps: [10000]
+ cycle_lengths: [10000000000000] # incredibly large number to prevent corner cases
+ f_start: [1.e-6]
+ f_max: [1.]
+ f_min: [1.]
+
+ unet_config:
+ target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+ params:
+ image_size: 32 # unused
+ in_channels: 4
+ out_channels: 4
+ model_channels: 320
+ attention_resolutions: [4, 2, 1]
+ num_res_blocks: 2
+ channel_mult: [1, 2, 4, 4]
+ num_heads: 8
+ use_spatial_transformer: True
+ transformer_depth: 1
+ context_dim: 768
+ use_checkpoint: True
+ legacy: False
+
+ first_stage_config:
+ target: ldm.models.autoencoder.AutoencoderKL
+ params:
+ embed_dim: 4
+ monitor: val/rec_loss
+ ddconfig:
+ double_z: true
+ z_channels: 4
+ resolution: 256
+ in_channels: 3
+ out_ch: 3
+ ch: 128
+ ch_mult:
+ - 1
+ - 2
+ - 4
+ - 4
+ num_res_blocks: 2
+ attn_resolutions: []
+ dropout: 0.0
+ lossconfig:
+ target: torch.nn.Identity
+
+ cond_stage_config:
+ target: ldm.modules.encoders.modules.PersonalizedCLIPEmbedder
+ params:
+ aesthetic_embedding_path: "aesthetic_embeddings/sac_8plus.pt"
+ T: 3
+ lr: 0.0001
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index ededbe4..04bf35f 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -1,12 +1,16 @@
import torch
import torch.nn as nn
+import torch.optim as optim
from functools import partial
import clip
from einops import rearrange, repeat
-from transformers import CLIPTokenizer, CLIPTextModel
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPProcessor, CLIPModel
import kornia
-from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
+from ldm.modules.x_transformer import (
+ Encoder,
+ TransformerWrapper,
+) # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
class AbstractEncoder(nn.Module):
@@ -17,9 +21,8 @@ def encode(self, *args, **kwargs):
raise NotImplementedError
-
class ClassEmbedder(nn.Module):
- def __init__(self, embed_dim, n_classes=1000, key='class'):
+ def __init__(self, embed_dim, n_classes=1000, key="class"):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
@@ -35,11 +38,15 @@ def forward(self, batch, key=None):
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
+
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer))
+ self.transformer = TransformerWrapper(
+ num_tokens=vocab_size,
+ max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ )
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
@@ -51,18 +58,27 @@ def encode(self, x):
class BERTTokenizer(AbstractEncoder):
- """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+ """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
+
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
from transformers import BertTokenizerFast # TODO: add to reuquirements
+
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@@ -79,20 +95,32 @@ def decode(self, text):
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
- def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
- device="cuda",use_tokenizer=True, embedding_dropout=0.0):
+
+ def __init__(
+ self,
+ n_embed,
+ n_layer,
+ vocab_size=30522,
+ max_seq_len=77,
+ device="cuda",
+ use_tokenizer=True,
+ embedding_dropout=0.0,
+ ):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
- self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
- attn_layers=Encoder(dim=n_embed, depth=n_layer),
- emb_dropout=embedding_dropout)
+ self.transformer = TransformerWrapper(
+ num_tokens=vocab_size,
+ max_seq_len=max_seq_len,
+ attn_layers=Encoder(dim=n_embed, depth=n_layer),
+ emb_dropout=embedding_dropout,
+ )
def forward(self, text):
if self.use_tknz_fn:
- tokens = self.tknz_fn(text)#.to(self.device)
+ tokens = self.tknz_fn(text) # .to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
@@ -104,29 +132,39 @@ def encode(self, text):
class SpatialRescaler(nn.Module):
- def __init__(self,
- n_stages=1,
- method='bilinear',
- multiplier=0.5,
- in_channels=3,
- out_channels=None,
- bias=False):
+ def __init__(
+ self,
+ n_stages=1,
+ method="bilinear",
+ multiplier=0.5,
+ in_channels=3,
+ out_channels=None,
+ bias=False,
+ ):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
- assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
+ assert method in [
+ "nearest",
+ "linear",
+ "bilinear",
+ "trilinear",
+ "bicubic",
+ "area",
+ ]
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
- print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
- self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
+ print(
+ f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
+ )
+ self.channel_mapper = nn.Conv2d(in_channels, out_channels, 1, bias=bias)
- def forward(self,x):
+ def forward(self, x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
-
if self.remap_output:
x = self.channel_mapper(x)
return x
@@ -134,15 +172,30 @@ def forward(self,x):
def encode(self, x):
return self(x)
-class FrozenCLIPEmbedder(AbstractEncoder):
- """Uses the CLIP transformer encoder for text (from Hugging Face)"""
- def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77):
+
+class PersonalizedCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder with the option of personalization with an aesthetic embedding"""
+
+ def __init__(
+ self,
+ version="openai/clip-vit-large-patch14",
+ device="cuda",
+ max_length=77,
+ T=5,
+ lr=0.0001,
+ aesthetic_embedding_path="aesthetic_embeddings/cloudcore.pt",
+ ):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
+ self.full_clip_processor = CLIPProcessor.from_pretrained(version)
self.device = device
self.max_length = max_length
- self.freeze()
+
+ self.T = T
+ self.lr = lr
+ self.aesthetic_embedding_path = aesthetic_embedding_path
+ # self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
@@ -150,13 +203,63 @@ def freeze(self):
param.requires_grad = False
def forward(self, text):
- batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
- return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
- tokens = batch_encoding["input_ids"].to(self.device)
- outputs = self.transformer(input_ids=tokens)
-
- z = outputs.last_hidden_state
- return z
+ with torch.enable_grad():
+ batch_encoding = self.tokenizer(
+ text,
+ truncation=True,
+ max_length=self.max_length,
+ return_length=True,
+ return_overflowing_tokens=False,
+ padding="max_length",
+ return_tensors="pt",
+ )
+ tokens = batch_encoding["input_ids"].to(self.device)
+
+ if text[0] != "":
+
+ # This is the model to be personalized
+ full_clip_model = CLIPModel.from_pretrained(
+ "openai/clip-vit-large-patch14",
+ ).to(self.device)
+
+ # We load the aesthetic embeddings
+ image_embs = torch.load(self.aesthetic_embedding_path).to(self.device)
+
+ # We compute the loss (similarity between the prompt embedding and the aesthetic embedding)
+ image_embs /= image_embs.norm(dim=-1, keepdim=True)
+ text_embs = full_clip_model.get_text_features(tokens)
+ text_embs /= text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ image_embs.T
+ loss = -sim
+ print(loss)
+
+ # lr = 0.0001
+
+ # We optimize the model to maximize the similarity
+ optimizer = optim.Adam(
+ full_clip_model.text_model.parameters(), lr=self.lr
+ )
+
+ # T = 0
+ for i in range(self.T):
+ optimizer.zero_grad()
+
+ loss.mean().backward()
+ optimizer.step()
+
+ text_embs = full_clip_model.get_text_features(tokens)
+ text_embs /= text_embs.norm(dim=-1, keepdim=True)
+ sim = text_embs @ image_embs.T
+ loss = -sim
+ print(loss)
+
+ z = full_clip_model.text_model(input_ids=tokens).last_hidden_state
+
+ else:
+ z = self.transformer(input_ids=tokens).last_hidden_state
+
+ self.freeze()
+ return z
def encode(self, text):
return self(text)
@@ -166,7 +269,15 @@ class FrozenCLIPTextEmbedder(nn.Module):
"""
Uses the CLIP transformer encoder for text.
"""
- def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
+
+ def __init__(
+ self,
+ version="ViT-L/14",
+ device="cuda",
+ max_length=77,
+ n_repeat=1,
+ normalize=True,
+ ):
super().__init__()
self.model, _ = clip.load(version, jit=False, device="cpu")
self.device = device
@@ -188,37 +299,46 @@ def forward(self, text):
def encode(self, text):
z = self(text)
- if z.ndim==2:
+ if z.ndim == 2:
z = z[:, None, :]
- z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
+ z = repeat(z, "b 1 d -> b k d", k=self.n_repeat)
return z
class FrozenClipImageEmbedder(nn.Module):
"""
- Uses the CLIP image encoder.
- """
+ Uses the CLIP image encoder.
+ """
+
def __init__(
- self,
- model,
- jit=False,
- device='cuda' if torch.cuda.is_available() else 'cpu',
- antialias=False,
- ):
+ self,
+ model,
+ jit=False,
+ device="cuda" if torch.cuda.is_available() else "cpu",
+ antialias=False,
+ ):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
self.antialias = antialias
- self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
- self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
+ self.register_buffer(
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
+ )
+ self.register_buffer(
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
+ )
def preprocess(self, x):
# normalize to [0,1]
- x = kornia.geometry.resize(x, (224, 224),
- interpolation='bicubic',align_corners=True,
- antialias=self.antialias)
- x = (x + 1.) / 2.
+ x = kornia.geometry.resize(
+ x,
+ (224, 224),
+ interpolation="bicubic",
+ align_corners=True,
+ antialias=self.antialias,
+ )
+ x = (x + 1.0) / 2.0
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
@@ -230,5 +350,6 @@ def forward(self, x):
if __name__ == "__main__":
from ldm.util import count_params
- model = FrozenCLIPEmbedder()
- count_params(model, verbose=True)
\ No newline at end of file
+
+ model = PersonalizedCLIPEmbedder()
+ count_params(model, verbose=True)
diff --git a/reference_images/aivazovsky/a-corner-of-constantinople-from-the-sea-by-moonlig-1878-peiz.jpg b/reference_images/aivazovsky/a-corner-of-constantinople-from-the-sea-by-moonlig-1878-peiz.jpg
new file mode 100644
index 0000000..1986c89
Binary files /dev/null and b/reference_images/aivazovsky/a-corner-of-constantinople-from-the-sea-by-moonlig-1878-peiz.jpg differ
diff --git a/reference_images/aivazovsky/full-moon-at-night-konstantinovich-watercolor.jpg b/reference_images/aivazovsky/full-moon-at-night-konstantinovich-watercolor.jpg
new file mode 100644
index 0000000..32346f0
Binary files /dev/null and b/reference_images/aivazovsky/full-moon-at-night-konstantinovich-watercolor.jpg differ
diff --git a/reference_images/aivazovsky/genoese-towers-in-the-black-sea-ivan-aivazovsky.jpg b/reference_images/aivazovsky/genoese-towers-in-the-black-sea-ivan-aivazovsky.jpg
new file mode 100644
index 0000000..ed03785
Binary files /dev/null and b/reference_images/aivazovsky/genoese-towers-in-the-black-sea-ivan-aivazovsky.jpg differ
diff --git a/scripts/gen_aesthetic_embeddings.py b/scripts/gen_aesthetic_embeddings.py
new file mode 100644
index 0000000..0f4e239
--- /dev/null
+++ b/scripts/gen_aesthetic_embeddings.py
@@ -0,0 +1,26 @@
+import clip
+import glob
+from PIL import Image
+import torch
+import tqdm
+
+# Just put your images in a folder inside reference_images/
+aesthetic_style = "aivazovsky"
+image_paths = glob.glob(f"reference_images/{aesthetic_style}/*")
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model, preprocess = clip.load("ViT-L/14", device=device)
+
+
+with torch.no_grad():
+ embs = []
+ for path in tqdm.tqdm(image_paths):
+ image = preprocess(Image.open(path)).unsqueeze(0).to(device)
+ emb = model.encode_image(image)
+ embs.append(emb.cpu())
+
+ embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+ # The generated embedding will be located here
+ torch.save(embs, f"aesthetic_embeddings/{aesthetic_style}.pt")
diff --git a/scripts/txt2img.py b/scripts/txt2img.py
index 59c16a1..cca71b7 100644
--- a/scripts/txt2img.py
+++ b/scripts/txt2img.py
@@ -18,7 +18,9 @@
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
from transformers import AutoFeatureExtractor
@@ -68,7 +70,7 @@ def load_model_from_config(config, ckpt, verbose=False):
def put_watermark(img, wm_encoder=None):
if wm_encoder is not None:
img = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
- img = wm_encoder.encode(img, 'dwtDct')
+ img = wm_encoder.encode(img, "dwtDct")
img = Image.fromarray(img[:, :, ::-1])
return img
@@ -77,7 +79,7 @@ def load_replacement(x):
try:
hwc = x.shape
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
- y = (np.array(y)/255.0).astype(x.dtype)
+ y = (np.array(y) / 255.0).astype(x.dtype)
assert y.shape == x.shape
return y
except Exception:
@@ -85,8 +87,12 @@ def load_replacement(x):
def check_safety(x_image):
- safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
- x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
+ safety_checker_input = safety_feature_extractor(
+ numpy_to_pil(x_image), return_tensors="pt"
+ )
+ x_checked_image, has_nsfw_concept = safety_checker(
+ images=x_image, clip_input=safety_checker_input.pixel_values
+ )
assert x_checked_image.shape[0] == len(has_nsfw_concept)
for i in range(len(has_nsfw_concept)):
if has_nsfw_concept[i]:
@@ -102,23 +108,23 @@ def main():
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
- help="the prompt to render"
+ help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
- default="outputs/txt2img-samples"
+ default="outputs/txt2img-samples",
)
parser.add_argument(
"--skip_grid",
- action='store_true',
+ action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
- action='store_true',
+ action="store_true",
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
@@ -129,17 +135,17 @@ def main():
)
parser.add_argument(
"--plms",
- action='store_true',
+ action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--laion400m",
- action='store_true',
+ action="store_true",
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
- action='store_true',
+ action="store_true",
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
@@ -204,7 +210,7 @@ def main():
parser.add_argument(
"--config",
type=str,
- default="configs/stable-diffusion/v1-inference.yaml",
+ default="configs/stable-diffusion/v1-inference-aesthetic.yaml",
help="path to config which constructs model",
)
parser.add_argument(
@@ -224,7 +230,25 @@ def main():
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
- default="autocast"
+ default="autocast",
+ )
+ parser.add_argument(
+ "--aesthetic_steps",
+ type=int,
+ help="number of steps for the aesthetic personalization",
+ default=10,
+ )
+ parser.add_argument(
+ "--aesthetic_lr",
+ type=int,
+ help="learning rate for the aesthetic personalization",
+ default=0.0001,
+ )
+ parser.add_argument(
+ "--aesthetic_embedding",
+ type=str,
+ help="aesthetic embedding file",
+ default="aesthetic_embeddings/sac_8plus.pt",
)
opt = parser.parse_args()
@@ -237,6 +261,14 @@ def main():
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
+
+ # Override config with personalization arguments
+ config.model.params.cond_stage_config.params.T = opt.aesthetic_steps
+ config.model.params.cond_stage_config.params.lr = opt.aesthetic_lr
+ config.model.params.cond_stage_config.params.aesthetic_embedding_path = (
+ opt.aesthetic_embedding
+ )
+
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -250,10 +282,12 @@ def main():
os.makedirs(opt.outdir, exist_ok=True)
outpath = opt.outdir
- print("Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)...")
+ print(
+ "Creating invisible watermark encoder (see https://github.com/ShieldMnt/invisible-watermark)..."
+ )
wm = "StableDiffusionV1"
wm_encoder = WatermarkEncoder()
- wm_encoder.set_watermark('bytes', wm.encode('utf-8'))
+ wm_encoder.set_watermark("bytes", wm.encode("utf-8"))
batch_size = opt.n_samples
n_rows = opt.n_rows if opt.n_rows > 0 else batch_size
@@ -275,9 +309,11 @@ def main():
start_code = None
if opt.fixed_code:
- start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
+ start_code = torch.randn(
+ [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device
+ )
- precision_scope = autocast if opt.precision=="autocast" else nullcontext
+ precision_scope = autocast if opt.precision == "autocast" else nullcontext
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
@@ -292,30 +328,43 @@ def main():
prompts = list(prompts)
c = model.get_learned_conditioning(prompts)
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
- samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
- conditioning=c,
- batch_size=opt.n_samples,
- shape=shape,
- verbose=False,
- unconditional_guidance_scale=opt.scale,
- unconditional_conditioning=uc,
- eta=opt.ddim_eta,
- x_T=start_code)
+ samples_ddim, _ = sampler.sample(
+ S=opt.ddim_steps,
+ conditioning=c,
+ batch_size=opt.n_samples,
+ shape=shape,
+ verbose=False,
+ unconditional_guidance_scale=opt.scale,
+ unconditional_conditioning=uc,
+ eta=opt.ddim_eta,
+ x_T=start_code,
+ )
x_samples_ddim = model.decode_first_stage(samples_ddim)
- x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
- x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ x_samples_ddim = torch.clamp(
+ (x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0
+ )
+ x_samples_ddim = (
+ x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy()
+ )
- x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
+ # x_checked_image, has_nsfw_concept = check_safety(x_samples_ddim)
+ x_checked_image = x_samples_ddim
- x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2)
+ x_checked_image_torch = torch.from_numpy(
+ x_checked_image
+ ).permute(0, 3, 1, 2)
if not opt.skip_save:
for x_sample in x_checked_image_torch:
- x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
+ x_sample = 255.0 * rearrange(
+ x_sample.cpu().numpy(), "c h w -> h w c"
+ )
img = Image.fromarray(x_sample.astype(np.uint8))
img = put_watermark(img, wm_encoder)
- img.save(os.path.join(sample_path, f"{base_count:05}.png"))
+ img.save(
+ os.path.join(sample_path, f"{base_count:05}.png")
+ )
base_count += 1
if not opt.skip_grid:
@@ -324,20 +373,21 @@ def main():
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
- grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+ grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
- grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+ grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
img = Image.fromarray(grid.astype(np.uint8))
img = put_watermark(img, wm_encoder)
- img.save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
+ img.save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
- print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
- f" \nEnjoy.")
+ print(
+ f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy."
+ )
if __name__ == "__main__":