Question regarding CIFAR10 standard deviation #129

Open
ssgosh opened this issue Mar 11, 2021 · 0 comments
ssgosh commented Mar 11, 2021

I computed the per-channel mean and standard deviation of the CIFAR10 training set. The means match the values used in main.py, but the standard deviations do not:

mean = [0.4913996458053589, 0.48215845227241516, 0.44653093814849854]
std = [0.2470322549343109, 0.24348513782024384, 0.26158788800239563]

Relevant line from main.py:

transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),

What causes this discrepancy? This is the code I used to compute the statistics:

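    import torch
    import torchvision
    from torch.utils.data import DataLoader

    # ToTensor() scales pixel values to [0, 1] and yields (channels, height, width) tensors.
    datasets = [torchvision.datasets.CIFAR10('./data', download=True, train=True, transform=torchvision.transforms.ToTensor())]
    channels = 3
    total = torch.zeros(channels)     # running per-channel sum of pixel values
    total_sq = torch.zeros(channels)  # running per-channel sum of squared pixel values
    num = 0                           # number of pixels seen per channel
    for ds in datasets:
        dl = DataLoader(ds, batch_size=1000)
        for images, targets in dl:
            # images has shape (batch, channels, height, width); reduce over all dims except channels.
            total += torch.sum(images, (0, 2, 3))
            total_sq += torch.sum(images * images, (0, 2, 3))
            num += images.shape[0] * images.shape[2] * images.shape[3]

    # Per-channel statistics pooled over every pixel of every image: std = sqrt(E[x^2] - E[x]^2).
    mean = total / num
    mean_sq = total_sq / num
    std = torch.sqrt(mean_sq - mean**2)

    print(f'mean = {mean.tolist()}, std = {std.tolist()}')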
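
One possible explanation, which nothing in this issue confirms, is that the values in main.py were obtained by computing the standard deviation inside each image and then averaging those per-image values, rather than pooling all pixels as the code above does. Pooling can never give the smaller of the two numbers, because the pooled variance also includes the variation in brightness between images. The sketch below illustrates the effect on hypothetical toy data. The helper names `global_std` and `mean_per_image_std` are made up for illustration, and the script is dependency-free so it runs without downloading CIFAR10.

```python
import random
import statistics


def global_std(images):
    """Population std over every pixel of every image pooled together."""
    pixels = [p for img in images for p in img]
    return statistics.pstdev(pixels)


def mean_per_image_std(images):
    """Average of the population std computed separately inside each image."""
    return statistics.fmean([statistics.pstdev(img) for img in images])


random.seed(0)

# Hypothetical toy data (not CIFAR10): 200 single-channel "images" of 64 pixels.
# Each image gets its own brightness level plus small within-image noise,
# and pixel values are clipped to [0, 1] like ToTensor() output.
images = []
for _ in range(200):
    brightness = random.uniform(0.2, 0.8)
    images.append(
        [min(1.0, max(0.0, random.gauss(brightness, 0.1))) for _ in range(64)]
    )

g = global_std(images)
m = mean_per_image_std(images)
print(f"global std = {g:.4f}, mean per-image std = {m:.4f}")
```

If this is the cause, replacing the pooled computation with a per-image one on the real training set should move the result toward (0.2023, 0.1994, 0.2010). That check is not part of this sketch.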