Skip to content

Commit

Permalink
Hot fix for MNIST Download Error.
Browse files Browse the repository at this point in the history
#63
This will be fixed in the next release of torch vision.
  • Loading branch information
maximilianreimer committed Mar 26, 2021
1 parent 411ec59 commit 4dc6f2a
Showing 1 changed file with 7 additions and 26 deletions.
33 changes: 7 additions & 26 deletions dacbench/envs/sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,11 @@
from backpack.extensions import BatchGrad
from gym.utils import seeding
from torchvision import datasets, transforms
import urllib
from dacbench import AbstractEnv

warnings.filterwarnings("ignore")


def set_header_for(url, filename):
opener = urllib.request.URLopener()
opener.addheader(
"User-Agent",
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_9_3) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/35.0.1916.47 Safari/537.36",
)
opener.retrieve(url, f"./{filename}")


class SGDEnv(AbstractEnv):
"""
Environment to control the learning rate of adam
Expand Down Expand Up @@ -234,22 +224,13 @@ def reset(self):
[transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

set_header_for(
"http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
"train-images-idx3-ubyte.gz",
)
set_header_for(
"http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
"train-labels-idx1-ubyte.gz",
)
set_header_for(
"http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
"t10k-images-idx3-ubyte.gz",
)
set_header_for(
"http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
"t10k-labels-idx1-ubyte.gz",
)
# hot fix for https://github.com/pytorch/vision/issues/3549
# If fix is available in stable version (0.9.1), we should update and be removed this.
new_mirror = "https://ossci-datasets.s3.amazonaws.com/mnist"
datasets.MNIST.resources = [
("/".join([new_mirror, url.split("/")[-1]]), md5)
for url, md5 in datasets.MNIST.resources
]

train_dataset = datasets.MNIST(
"../data", train=True, download=True, transform=transform
Expand Down

0 comments on commit 4dc6f2a

Please sign in to comment.