diff --git a/tests/unit/alexnet_model.py b/tests/unit/alexnet_model.py index ab51a4630b7f..b94bd7052a9f 100755 --- a/tests/unit/alexnet_model.py +++ b/tests/unit/alexnet_model.py @@ -114,7 +114,7 @@ def cifar_trainset(fp16=False): dist.barrier() if local_rank != 0: dist.barrier() - trainset = torchvision.datasets.CIFAR10(root='/tmp/cifar10-data', + trainset = torchvision.datasets.CIFAR10(root='/blob/cifar10-data', train=True, download=True, transform=transform)