Skip to content

PyTorch implementation of 'CycleGAN' (Zhu et al., 2017) and training it on 6 datasets

Notifications You must be signed in to change notification settings

KimRass/CycleGAN

Repository files navigation

'CycleGAN' (Zhu et al., 2017) implementation from scratch in PyTorch

Paper Reading

How to Use

Image Generation

# For example,
python3 generate_images.py\
    --ds_name="monet2photo"\
    --data_dir=".../monet2photo/"\
    --x_or_y="x"\
    --ckpt_path=".../monet_to_photo.pth"\
    --n_cpus=0 # Optional

Pre-trained Models and Generated Images

Pre-trained model Generated images on test set
Monet to photo cyclegan_monet_to_photo.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/monet_to_photo
Photo to Monet cyclegan_photo_to_monet.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/photo_to_monet
Vangogh to photo cyclegan_vangogh_to_photo.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/vangogh_to_photo
Photo to Vangogh cyclegan_photo_to_vangogh.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/photo_to_vangogh
Cezanne to photo cyclegan_cezanne_to_photo.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/cezanne_to_photo
Photo to Cezanne cyclegan_photo_to_cezanne.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/photo_to_cezanne
Ukiyo-e to photo cyclegan_ukiyoe_to_photo.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/ukiyoe_to_photo
Photo to ukiyo-e cyclegan_photo_to_ukiyoe.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/photo_to_ukiyoe
Horse to zebra cyclegan_horse_to_zebra.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/horse_to_zebra
Zebra to horse cyclegan_zebra_to_horse.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/zebra_to_horse
Summer to winter cyclegan_summer_to_winter.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/summer_to_winter
Winter to summer cyclegan_winter_to_summer.pth https://github.com/KimRass/CycleGAN/tree/main/generated_images/winter_to_summer
  • 'horse2zebra' dataset의 경우 이미지 중 일부가 화질이 좋지 않거나 이미지의 비율이 왜곡되어 있고 이미지 내에서 말 또는 얼룩말이 차지하는 영역이 매우 작은 등 개인적으로 데이터셋의 품질 자체가 좋지 않다고 생각합니다. 좀 더 대량의, 고품질 데이터셋을 구축하여 다시 학습시켜보면 더욱 우수한 모델을 얻를 수 있을 것입니다.

Monet to Photo

Photo to Monet

Vangogh to Photo

Photo to Vangogh

Cezanne to Photo

Photo to Cezanne

Ukiyoe to Photo

Photo to Ukiyoe

Horse to Zebra

Zebra to Horse

Implementation Details

  • 논문만 가지고는 정확히 알기 어려운 부분 또는 논문과 공식 저장소가 서로 다른 부분이 많은데, 공식 저장소를 기준으로 구현했습니다.

Merging Optimizers

    • discriminators (Dx와 Dy)와 generators (Gx와 Gy)의 objective는 방향성이 서로 충돌하지만 (adversarial training) Dx와 Dy 그리고 Gx와 Gy는 서로 objective의 방향성이 동일하므로, Dx의 Optimizer와 Dy의 Optimizer를 하나로 합치고, Gx의 Optimizer와 Gy의 Optimizer를 하나로 합쳤습니다.
  • As-is:
    disc_x_optim = Adam(params=disc_x.parameters(), lr=lr)
    disc_y_optim = Adam(params=disc_y.parameters(), lr=lr)
    gen_x_optim = Adam(params=gen_x.parameters(), lr=lr)
    gen_y_optim = Adam(params=gen_y.parameters(), lr=lr)
  • To-be:
    disc_optim = Adam(params=list(disc_x.parameters()) + list(disc_y.parameters()), lr=lr)
    gen_optim = Adam(params=list(gen_x.parameters()) + list(gen_y.parameters()), lr=lr)

Padding Mode

  • 논문에서는 모든 padding에 대해서 padding_mode="reflect"를 사용한 것처럼 쓰여 있으나 공식 저장소를 보면 padding_mode="zeros"padding_mode="reflect"를 혼용하고 있어 이를 따랐습니다.

Image Pairing

  • 이미지의 집합 X와 Y의 크기가 서로 다르므로 만약 X의 크기가 Y의 크기보다 크다면 1 epoch 동안 X의 이미지가 한 번씩 모델에 입력으로 들어갈 때 Y의 이미지는 한 번 이상씩 모델에 입력으로 들어가게 됩니다. 즉 X의 크기가 데이터의 크기가 됩니다. 이 점을 간과해 데이터의 크기를 Y의 크기와 같게 했고 X와 Y의 각 원소를 정해진대로 1:1 대응이 되도록 코드를 짰었으나 이를 수정했습니다.
  • As-is:
    x_path = self.x_paths[idx]
    y_path = self.y_paths[idx]
  • To-be:
    if self.x_len >= self.y_len:
        x_path = self.x_paths[idx]
        y_path = random.choice(self.y_paths)
    else:
        y_path = self.y_paths[idx]
        x_path = random.choice(self.x_paths)

LSGANs

  • 논문에 따르면 objective로서 'negative log likelihood' (GAN_CRIT = nn.BCEWithLogitsLoss()) 대신에 'least-squares' (GAN_CRIT = nn.MSELoss())를 사용합니다. [1] 전자를 사용할 경우 금방 mode collapse가 발생하는 것을 관찰할 수 있었습니다.

Gatys et al. (2016)

References

About

PyTorch implementation of 'CycleGAN' (Zhu et al., 2017) and training it on 6 datasets

Topics

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published