Skip to content

Commit 8c4c9cc

Browse files
author
Ubuntu
committed
Make it PT Only compatible
1 parent 5f2e61c commit 8c4c9cc

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

docker/Dockerfile

+4-4
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ RUN if [ "${cloud_build}" = true ]; then github_branch="${release_version}" && \
4646
cd / && \
4747
git clone -b "${github_branch}" --recursive https://github.com/pytorch-tpu/examples tpu-examples; fi
4848

49-
# RUN cd /pytorch && bash xla/scripts/build_torch_wheels.sh ${python_version} ${release_version}
49+
RUN cd /pytorch && bash xla/scripts/build_torch_wheels.sh ${python_version} ${release_version}
5050

5151
# Use conda environment on startup or when running scripts.
52-
# RUN echo "conda activate pytorch" >> ~/.bashrc
53-
# RUN echo "export TF_CPP_LOG_THREAD_ID=1" >> ~/.bashrc
54-
# ENV PATH /root/anaconda3/envs/pytorch/bin/:/root/bin:$PATH
52+
RUN echo "conda activate pytorch" >> ~/.bashrc
53+
RUN echo "export TF_CPP_LOG_THREAD_ID=1" >> ~/.bashrc
54+
ENV PATH /root/anaconda3/envs/pytorch/bin/:/root/bin:$PATH
5555

5656
# Define entrypoint and cmd
5757
COPY docker/docker-entrypoint.sh /usr/local/bin

test/__init__.py

Whitespace-only changes.

test/test_train_mp_imagenet_torch_amp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import args_parse
1+
import test.args_parse as args_parse
22

33
SUPPORTED_MODELS = [
44
"alexnet",
@@ -56,7 +56,7 @@
5656

5757
import os
5858

59-
import schedulers
59+
import test.schedulers as schedulers
6060
import numpy as np
6161
import torch
6262
import torch.nn as nn
@@ -67,7 +67,7 @@
6767
import torch_xla.utils.utils as xu
6868
import torch_xla.test.test_utils as test_utils
6969
import torch_xla.core.xla_model as xm
70-
from classification_benchmark_constants import DEFAULT_KWARGS, MODEL_SPECIFIC_DEFAULTS
70+
from test.classification_benchmark_constants import DEFAULT_KWARGS, MODEL_SPECIFIC_DEFAULTS
7171
from torch.cuda.amp import GradScaler, autocast
7272

7373
default_value_dict = MODEL_SPECIFIC_DEFAULTS.get(FLAGS.model, DEFAULT_KWARGS)

0 commit comments

Comments
 (0)