This is the training code for the Jax/Flax implementation of Few-shot Image Generation via Cross-domain Correspondence.
- Getting Started
- Preparing Datasets for Training
- Training
- Checkpoints
- Generating Images
- References
- License
You will need Python 3.7 or later.
- Clone the repository:
> git clone https://github.com/matthias-wright/flaxmodels.git
- Go into the directory:
> cd flaxmodels/training/few_shot_gan_adaption
- Install Jax with CUDA.
- Install requirements:
> pip install -r requirements.txt
Before training, the images should be stored in a TFRecord dataset. The TFRecord format stores your data as a sequence of bytes, which allows for fast data loading.
Alternatively, you can use tfds.folder_dataset.ImageFolder on the image directory directly, but you will have to replace the tf.data.TFRecordDataset in data_pipeline.py with tfds.folder_dataset.ImageFolder (see this thread for more info).
- Download dataset from here.
- Put all images into a directory:
/path/to/image_dir/ 0.jpg 1.jpg 2.jpg 4.jpg ...
- Create TFRecord dataset:
> python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
--image_dir is the path to the image directory.
--data_dir is the path where the TFRecord dataset is stored.
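Before running the conversion, it can help to check what is actually in the image directory. The following is a minimal sketch that only uses the standard library; it is not part of this repository, and the helper names and the set of accepted extensions are assumptions made for illustration.

```python
from pathlib import Path

# Extensions treated as images here are an assumption for this sketch.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def list_images(image_dir):
    """Return the image files directly inside image_dir, sorted by name."""
    root = Path(image_dir)
    if not root.is_dir():
        raise NotADirectoryError(f"{root} is not a directory")
    return sorted(
        p for p in root.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


def summarize(image_dir):
    """Count images and other entries so surprises show up before conversion."""
    root = Path(image_dir)
    images = list_images(root)
    others = [p for p in root.iterdir() if p not in images]
    return {"num_images": len(images), "num_other": len(others)}
```

Calling summarize('/path/to/image_dir/') returns a small dictionary with the number of images and the number of other entries, which is a quick way to spot stray files before creating the TFRecord dataset.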
Download checkpoint of source model:
> wget https://www.dropbox.com/s/hyh1k8ixtzy24ye/ffhq_256x256.pickle\?dl\=1 -O ffhq_256x256.pickle
Start training:
> CUDA_VISIBLE_DEVICES=a,b,c,d python main.py --data_dir /path/to/tfrecord --source_ckpt_path ffhq_256x256.pickle
Here a, b, c, d are the GPU indices. Multi-GPU training (data parallelism) works by default and will automatically use all the devices that you make visible.
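To give an idea of what data parallelism means for a batch, here is a small sketch of the usual sharding step: the batch is reshaped so that its leading axis has one slice per device. This is not the code from training.py; the function name shard_batch and the shapes are assumptions for illustration, and NumPy stands in for the device arrays.

```python
import numpy as np


def shard_batch(batch, num_devices):
    """Reshape (batch, ...) to (num_devices, batch // num_devices, ...)."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError("batch size must be divisible by the number of devices")
    per_device = batch.shape[0] // num_devices
    return batch.reshape((num_devices, per_device) + batch.shape[1:])


# In JAX, the device count would typically come from jax.local_device_count().
num_devices = 4  # e.g. CUDA_VISIBLE_DEVICES=0,1,2,3
images = np.zeros((16, 256, 256, 3), dtype=np.float32)
sharded = shard_batch(images, num_devices)
print(sharded.shape)  # (4, 4, 256, 256, 3)
```

Each device then processes its own slice of 4 images, which is why the global batch size should be divisible by the number of visible GPUs in a setup like this.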
I use Weights & Biases for logging, but you can simply replace it with the logging method of your choice. All logging happens in the training loop implemented in training.py. To use logging with Weights & Biases, use --wandb.
By default, the FID score is evaluated on 10,000 images every 1000 training steps. The checkpoint with the lowest FID score is saved (for FID, lower is better). You can change the evaluation frequency using the --eval_fid_every argument and the number of images used for the FID evaluation using --num_fid_images.
You can disable the FID score evaluation using --disable_fid. In that case, a checkpoint will be saved every 2000 steps (can be changed using --save_every).
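The checkpointing rules described above can be sketched as a small policy object. This is only an illustration of the logic, not the code from training.py: the class and method names are hypothetical, the FID values in the usage example are made up, and it assumes the best checkpoint is the one with the lowest FID (lower FID is better).

```python
import math


class CheckpointPolicy:
    """Decide when to evaluate FID and when to save (hypothetical helper)."""

    def __init__(self, eval_fid_every=1000, save_every=2000, disable_fid=False):
        self.eval_fid_every = eval_fid_every
        self.save_every = save_every
        self.disable_fid = disable_fid
        self.best_fid = math.inf  # lower FID is better

    def should_evaluate(self, step):
        """FID is evaluated every eval_fid_every steps unless disabled."""
        return (not self.disable_fid
                and step > 0
                and step % self.eval_fid_every == 0)

    def should_save(self, step, fid=None):
        """Save on a fixed schedule without FID, else only on a new best FID."""
        if self.disable_fid:
            return step > 0 and step % self.save_every == 0
        if fid is not None and fid < self.best_fid:
            self.best_fid = fid
            return True
        return False


# Usage with made-up FID values: only steps that improve the FID are saved.
policy = CheckpointPolicy()
fids = {1000: 85.0, 2000: 60.5, 3000: 72.3}
saved = [step for step, fid in fids.items() if policy.should_save(step, fid)]
print(saved)  # [1000, 2000]
```

With disable_fid=True, the same object falls back to saving at every multiple of save_every, which mirrors the behavior of --disable_fid and --save_every.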
- Sketches (357.2 MB)
- Amedeo Modigliani (357.2 MB)
- Babies (357.2 MB)
- Otto Dix (357.2 MB)
- Rafael (357.2 MB)
import jax
import numpy as np
import dill as pickle
from PIL import Image
import flaxmodels as fm
# Load the checkpoint; the context manager closes the file afterwards
with open('sketches.pickle', 'rb') as f:
    ckpt = pickle.load(f)
params = ckpt['params_ema_G']
generator = fm.few_shot_gan_adaption.Generator()
# Seed
key = jax.random.PRNGKey(0)
# Input noise
z = jax.random.normal(key, shape=(4, 512))
# Generate images
images, _ = generator.apply(params, z, truncation_psi=0.5, train=False, noise_mode='const')
# Normalize images to be in range [0, 1]
images = (images - np.min(images)) / (np.max(images) - np.min(images))
# Save images
for i in range(images.shape[0]):
    Image.fromarray(np.uint8(images[i] * 255)).save(f'image_{i}.jpg')
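If you would rather save all generated images as a single file, they can be tiled into a grid first. The following is a minimal NumPy sketch; make_grid is a hypothetical helper that is not part of flaxmodels, and random arrays stand in for the generator output so the example runs on its own.

```python
import numpy as np


def make_grid(images, num_cols):
    """Tile images of shape (N, H, W, C) into one (rows*H, num_cols*W, C) array."""
    n, h, w, c = images.shape
    num_rows = -(-n // num_cols)  # ceiling division
    # Pad with black images so the batch fills the grid exactly.
    padded = np.zeros((num_rows * num_cols, h, w, c), dtype=images.dtype)
    padded[:n] = images
    grid = padded.reshape(num_rows, num_cols, h, w, c)
    grid = grid.transpose(0, 2, 1, 3, 4)  # (rows, H, cols, W, C)
    return grid.reshape(num_rows * h, num_cols * w, c)


# Stand-in for normalized generator output in [0, 1] (assumption for this demo).
demo_images = np.random.default_rng(0).random((4, 8, 8, 3))
grid = make_grid(demo_images, num_cols=2)
print(grid.shape)  # (16, 16, 3)
```

With the real output, np.asarray(images) converts the normalized JAX array to NumPy, and the resulting grid can be saved with Image.fromarray(np.uint8(grid * 255)) in the same way as the individual images.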