
data_pipeline.py needs more changes than suggested in README to support ImageFolder datasets #8

Closed
josephrocca opened this issue Oct 8, 2021 · 5 comments


@josephrocca
Contributor

josephrocca commented Oct 8, 2021

Some problems I ran into:

  • I wasn't able to get tfds.ImageFolder working with a "flat" folder of images. I had to nest a dummy label folder inside a dummy split folder. I followed the instructions here: https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder
  • There doesn't seem to be a num_examples property in tfds.core.DatasetInfo, so I had to use builder.info.splits['fake_split'].num_examples where fake_split is the name of my dummy split folder. It does look like there's a total_num_examples property, but I'm not sure how to access it - maybe it's a private field (though I'm not sure if those are possible in Python)?
  • I had to edit pre_process because it was expecting protobufs instead of {image, label} objects.
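
A minimal sketch of the workaround from the first bullet, assuming the layout that the tfds.ImageFolder documentation describes (`<root>/<split>/<label>/<image>`). The helper `nest_flat_folder` and the default names `fake_split`/`fake_label` are hypothetical, chosen to match the split name used in the pipeline code below, and the set of image extensions is an assumption. It only uses the standard library, so it can be run before TensorFlow is involved.

```python
import os
import shutil
from pathlib import Path

# Assumed set of extensions that tfds.ImageFolder will pick up.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def nest_flat_folder(src_dir, dst_dir, split="fake_split", label="fake_label", link=False):
    """Place every image from a flat src_dir under dst_dir/<split>/<label>/.

    Hypothetical helper for the nesting tfds.ImageFolder expects. Copies files
    by default; with link=True it creates symlinks instead. Returns the number
    of images placed.
    """
    target = Path(dst_dir) / split / label
    target.mkdir(parents=True, exist_ok=True)
    count = 0
    for entry in sorted(Path(src_dir).iterdir()):
        # Skip sub-directories and anything that does not look like an image.
        if not entry.is_file() or entry.suffix.lower() not in IMAGE_EXTS:
            continue
        dest = target / entry.name
        if link:
            os.symlink(entry.resolve(), dest)
        else:
            shutil.copy2(entry, dest)
        count += 1
    return count
```

With the images under `dst_dir/fake_split/fake_label/`, `tfds.ImageFolder(dst_dir)` should expose one split with a single class, and the count is `builder.info.splits['fake_split'].num_examples`. In recent TFDS versions `builder.info.splits` is a `SplitDict` whose `total_num_examples` property sums over all splits, which is probably the attribute the second bullet refers to. Copying is the default here; `link=True` saves disk space but assumes your filesystem supports symlinks.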

Note that the reason I am using the ImageFolder approach is that the TFRecord approach blew my 3 GB dataset up to 200 GB, since I think it's storing the raw tensor data? I'm new to this, but it seems like it would make more sense to just store the data in JPEG format, since JPEG decoding is so fast? That said, even if the TFRecord approach used a reasonable amount of space, I'd probably still prefer the ImageFolder approach, since it just seems nicer and more portable. Even better, from my (newbie) perspective, would be the ability to load a tar of images with any internal directory structure.
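
For the tar idea above, here is a standard-library sketch; it is not something this repository provides. The hypothetical helper `iter_tar_images` walks an archive at any nesting depth and yields each image's still-encoded bytes, so nothing is inflated to raw tensors. Feeding these bytes into tf.data (for example through `tf.data.Dataset.from_generator` followed by `tf.io.decode_image`) is left out and would be an assumption about how you integrate it.

```python
import tarfile
from pathlib import PurePosixPath

# Assumed image extensions; adjust to whatever your archive contains.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def iter_tar_images(tar_path):
    """Yield (member_name, encoded_bytes) for every image file in a tar archive.

    The internal directory structure is ignored, so images may sit at any depth.
    Bytes are returned still encoded (JPEG/PNG), i.e. as small as on disk.
    """
    # "r:*" lets tarfile detect compression (.tar, .tar.gz, .tar.bz2, .tar.xz).
    with tarfile.open(tar_path, mode="r:*") as tar:
        for member in tar:
            if not member.isfile():
                continue
            if PurePosixPath(member.name).suffix.lower() not in IMAGE_EXTS:
                continue
            fileobj = tar.extractfile(member)
            if fileobj is not None:
                yield member.name, fileobj.read()
```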

Below is my new data_pipeline.py so far. It seems to work okay now, but I haven't got training to work yet as I'm still debugging some stuff. Will update this post if I run into any more problems with data_pipeline.py.

```python
import json  # only used by the commented-out TFRecord loader below
import os  # only used by the commented-out TFRecord loader below

import flax.jax_utils
import jax
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds


def prefetch(dataset, n_prefetch):
    # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
    ds_iter = iter(dataset)
    # Convert every tf.Tensor leaf to a NumPy array via the buffer protocol.
    ds_iter = map(lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x),
                  ds_iter)
    if n_prefetch:
        ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
    return ds_iter


def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size,
             shuffle_buffer=1000, split='fake_split'):
    """Build a sharded, batched tf.data pipeline from a tfds.ImageFolder dataset.

    Args:
        data_dir (str): Root directory of the dataset (<data_dir>/<split>/<label>/<image>).
        img_size (int): Image size for training.
        img_channels (int): Number of image channels.
        num_classes (int): Number of classes, 0 for no classes.
        num_devices (int): Number of devices.
        batch_size (int): Batch size (per device).
        shuffle_buffer (int): Buffer used for shuffling the dataset.
        split (str): Name of the (dummy) split folder inside data_dir.

    Returns:
        (tf.data.Dataset, dict): Dataset, and a dict with 'num_examples' and 'num_classes'.
    """

    def pre_process(example):
        # The original TFRecord pipeline parsed serialized tf.train.Example protos:
        # feature = {'height': tf.io.FixedLenFeature([], tf.int64),
        #            'width': tf.io.FixedLenFeature([], tf.int64),
        #            'channels': tf.io.FixedLenFeature([], tf.int64),
        #            'image': tf.io.FixedLenFeature([], tf.string),
        #            'label': tf.io.FixedLenFeature([], tf.int64)}
        # example = tf.io.parse_single_example(serialized_example, feature)

        # height = tf.cast(example['height'], dtype=tf.int64)
        # width = tf.cast(example['width'], dtype=tf.int64)
        # channels = tf.cast(example['channels'], dtype=tf.int64)

        # image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
        # image = tf.reshape(image, shape=[height, width, channels])

        # tfds.ImageFolder already yields decoded {'image': uint8 [H, W, C], 'label': int64} dicts.
        image = example['image']

        image = tf.cast(image, dtype='float32')
        image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
        image = tf.image.random_flip_left_right(image)

        # Map pixel values from [0, 255] to [-1, 1].
        image = (image - 127.5) / 127.5

        label = tf.one_hot(example['label'], num_classes)
        return {'image': image, 'label': label}

    def shard(data):
        # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
        # because the first dimension will be mapped across devices using jax.pmap
        data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
        data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
        return data

    # print('Loading TFRecord...')
    # with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
    #    dataset_info = json.load(fin)

    # ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
    # ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))

    builder = tfds.ImageFolder(data_dir)
    print(builder.info)
    ds = builder.as_dataset(split=split, shuffle_files=True)
    # DatasetInfo has no num_examples attribute; the count lives on the split.
    num_examples = builder.info.splits[split].num_examples
    # One label folder per class, so a single dummy label folder gives num_classes == 1.
    dataset_info = {'num_examples': num_examples,
                    'num_classes': builder.info.features['label'].num_classes}

    ds = ds.shuffle(min(num_examples, shuffle_buffer))
    ds = ds.map(pre_process, tf.data.AUTOTUNE)
    ds = ds.batch(batch_size * num_devices, drop_remainder=True)
    ds = ds.map(shard, tf.data.AUTOTUNE)
    ds = ds.prefetch(1)
    return ds, dataset_info
```
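
The normalization at the end of `pre_process` and the reshape in `shard` can be checked in isolation with NumPy. This sketch mirrors, rather than reuses, the TensorFlow code: `normalize` maps `[0, 255]` to `[-1, 1]` the way `(image - 127.5) / 127.5` does, and `shard` splits a global batch `[num_devices * batch_size, H, W, C]` into `[num_devices, batch_size, H, W, C]` for `jax.pmap`. The sizes used (8 devices, 4 images per device, 64x64 RGB) are arbitrary.

```python
import numpy as np


def normalize(images_uint8):
    """Map uint8 pixel values in [0, 255] to float32 values in [-1, 1]."""
    return (images_uint8.astype(np.float32) - 127.5) / 127.5


def shard(batch, num_devices):
    """Reshape [num_devices * batch_size, ...] to [num_devices, batch_size, ...]."""
    if batch.shape[0] % num_devices != 0:
        raise ValueError("global batch size must be divisible by num_devices")
    return batch.reshape(num_devices, -1, *batch.shape[1:])


num_devices, per_device_batch, img_size, img_channels = 8, 4, 64, 3
rng = np.random.default_rng(0)
global_batch = rng.integers(
    0, 256, size=(num_devices * per_device_batch, img_size, img_size, img_channels), dtype=np.uint8
)
sharded = shard(normalize(global_batch), num_devices)
print(sharded.shape)  # (8, 4, 64, 64, 3)
```

Because the TF pipeline batches with `drop_remainder=True` before calling `shard`, the divisibility check should never trigger there; it only guards the standalone version.
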
@matthias-wright
Owner

Hi @josephrocca, great to hear that you are using the repository for training!
The TFRecord dataset will be larger than your image folder but an increase from 3 GB to 200 GB seems too large.
As an example, the FFHQ image folder with resolution 1024x1024 in PNG format is 90 GB and the corresponding TFRecord dataset is 189 GB.
Can you give some details about your images (number of images, format, resolution, etc)?
I really like the TFRecord format because of the efficient data loading. I am sorry that I did not give enough details for using tfds.ImageFolder though.

@josephrocca
Contributor Author

> great to hear that you are using the repository for training!

Thanks for your work on this repo!

To reproduce, download the image below and name it img_0.jpg. Put it in a new, empty folder, open a terminal in that folder, and run:

```bash
for i in {1..40000}; do cp img_0.jpg "img_$i.jpg"; done
```

Then, per the stylegan2 training README, run:

```bash
python dataset_utils/images_to_tfrecords.py --image_dir /path/to/image_dir/ --data_dir /path/to/tfrecord
```

The dataset grows from about 3.5 GB to more than 200 GB.

[Image: img_0]

@matthias-wright
Owner

matthias-wright commented Oct 15, 2021

Hi @josephrocca, thanks for the info!
You are using JPEG images, which are highly compressed. The image you uploaded is 92.5 kilobytes as a JPEG. However, if you represent the same image as a uint8 tensor, the tensor uses 5.84 megabytes.
This is because (1162 * 1758 * 3) / (1024 * 1024) = 5.84448623657.
So those 40000 images, which take up 3.5 GB stored as JPEG, actually take up about 228 GB as uint8 tensors (40000 * 5.84448623657 / 1024 ≈ 228.3, assuming they all have the same size).
That is where the size of the TFRecord dataset is coming from.
Are you actually training at resolution 1162 x 1758? If not, I would suggest resizing the images before storing them as a TFRecord dataset. That will reduce the size. Hope that helps!
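
The size arithmetic above as a small helper, assuming uncompressed uint8 storage (one byte per channel value), images that all share one resolution, and 1024-based units as in the calculation above. `raw_tensor_bytes` is a hypothetical name, and 256x256 is an arbitrary example training resolution, not one mentioned in this thread; it only illustrates how much resizing before writing the TFRecords could save.

```python
def raw_tensor_bytes(height, width, channels=3, bytes_per_value=1):
    """Bytes needed to store one image as an uncompressed tensor (uint8 by default)."""
    return height * width * channels * bytes_per_value


GIB = 1024 ** 3
num_images = 40_000

full = raw_tensor_bytes(1162, 1758) * num_images
resized = raw_tensor_bytes(256, 256) * num_images  # 256x256 is an illustrative choice
print(f"full resolution: {full / GIB:.1f} GiB")  # prints "full resolution: 228.3 GiB"
print(f"resized to 256x256: {resized / GIB:.1f} GiB")  # prints "resized to 256x256: 7.3 GiB"
```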

@josephrocca
Contributor Author

@matthias-wright Yep, I suspected it was because they're being stored as raw data. That seems like a bad idea for large image datasets, though, given the huge size inflation relative to JPEG. Since JPEG decoders on modern hardware are really fast, would the JPEG decoding step actually be a bottleneck in training?

In any case, the ImageFolder approach works nicely for me with the changes I mentioned in the original post. This issue was more about the changes needed in the stylegan2 training code (specifically data_pipeline.py) to match the instructions in the README, since the current ImageFolder instructions don't work. Perhaps adding a link in the README to my comment above would do for now, but I'll leave that up to you.

Thanks again for your work on this repo! (Crossing my fingers that you'll work on stylegan3-flax next :P)

@matthias-wright
Owner

I agree that this is not very efficient for storage purposes. I guess most people are willing to trade some storage capacity for training speed, to some extent. Great to hear that you got it working! I added a link from the README to this thread, thanks for that! Haha, I hope that I will find some time to work on stylegan3, but I'm not sure yet.
