Skip to content

Commit 1d3a65e

Browse files
committed
Add a wrapper module to import Tensorflow.
This is to workaround a known issue with packaging Tensorflow as a pip dependency in bazel. See: github.com/bazel-contrib/rules_python/issues/71 Signed-off-by: format 2020.01.12 <github.com/ChrisCummins/format>
1 parent c990998 commit 1d3a65e

File tree

3 files changed

+55
-8
lines changed

3 files changed

+55
-8
lines changed

deeplearning/clgen/models/helper.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
import tensorflow as tf
2-
import tensorflow.contrib.seq2seq as seq2seq
31
from tensorflow.python.ops import math_ops
42
from tensorflow.python.ops.distributions import categorical
53
from tensorflow.python.util import nest
64

5+
import third_party.py.tensorflow.tf.contrib.seq2seq as seq2seq
6+
from third_party.py.tensorflow import tf
7+
78

89
class CustomInferenceHelper(seq2seq.TrainingHelper):
910
"""An inference helper that takes a seed text"""

deeplearning/clgen/models/tensorflow_backend.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def __init__(self, *args, **kwargs):
8787

8888
# Create the summary writer, shared between Train() and
8989
# _EndOfEpochTestSample().
90-
import tensorflow as tf
90+
from third_party.py.tensorflow import tf
9191

9292
tensorboard_dir = f"{self.cache.path}/tensorboard"
9393
app.Log(
@@ -124,8 +124,8 @@ def InitTfGraph(
124124
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
125125

126126
# Deferred importing of TensorFlow.
127-
import tensorflow as tf
128-
import tensorflow.contrib.seq2seq as seq2seq
127+
from third_party.py.tensorflow import tf
128+
import third_party.py.tensorflow.tf.contrib.seq2seq as seq2seq
129129
from tensorflow.contrib import rnn
130130
from deeplearning.clgen.models import helper
131131

@@ -483,7 +483,7 @@ def _EndOfEpochTestSample(
483483
self, corpus, sampler: samplers.Sampler, step: int, epoch_num: int
484484
):
485485
"""Run sampler"""
486-
import tensorflow as tf
486+
from third_party.py.tensorflow import tf
487487

488488
atomizer = corpus.atomizer
489489
sampler.Specialize(atomizer)
@@ -545,7 +545,7 @@ def InitSampling(
545545
self, sampler: samplers.Sampler, seed: typing.Optional[int] = None
546546
) -> None:
547547
"""Initialize model for sampling."""
548-
import tensorflow as tf
548+
from third_party.py.tensorflow import tf
549549

550550
# Delete any previous sampling session.
551551
if self.inference_tf:
@@ -622,7 +622,7 @@ def SampleNextIndices(self, sampler: samplers.Sampler, done: np.ndarray):
622622
return generated
623623

624624
def RandomizeSampleState(self) -> None:
625-
import tensorflow as tf
625+
from third_party.py.tensorflow import tf
626626

627627
self.inference_state = [
628628
tf.nn.rnn_cell.LSTMStateTuple(

third_party/py/tensorflow/tf.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Import Tensorflow.
2+
3+
This module is a drop-in replacement for regular tensorflow. Replace:
4+
5+
import tensorflow as tf
6+
7+
with:
8+
9+
from third_party.py.tensorflow import tf
10+
11+
This wrapper is required to workaround a known bug with packaging Tensorflow
12+
as a pip dependency with bazel. See:
13+
github.com/bazelbuild/rules_python/issues/71
14+
"""
15+
import importlib
16+
import pathlib
17+
import sys
18+
19+
try:
20+
# Try importing Tensorflow the vanilla way. This will succeed once
21+
# github.com/bazelbuild/rules_python/issues/71 is fixed.
22+
import tensorflow
23+
except (ImportError, ModuleNotFoundError):
24+
# That failed, so see if there is a system install of Tensorflow that we
25+
# can trick python into importing. This should succeed in the
26+
# chriscummins/phd_base_tf_cpu docker image.
27+
PYTHON_SITE_PACKAGES = pathlib.Path("/usr/local/lib/python3.7/site-packages")
28+
29+
try:
30+
if pathlib.Path(PYTHON_SITE_PACKAGES / "tensorflow").is_dir():
31+
sys.path.insert(0, "/usr/local/lib/python3.7/site-packages")
32+
33+
import tensorflow
34+
else:
35+
raise ModuleNotFoundError
36+
except (ImportError, ModuleNotFoundError):
37+
# That failed, so a final hail mary let's try importing the module directly.
38+
tensorflow = importlib.import_module(
39+
"tensorflow", PYTHON_SITE_PACKAGES / "tensorflow",
40+
)
41+
42+
# Import Tensorflow into this module's namespace.
43+
from tensorflow import *
44+
45+
# Pretend that we've imported the regular Tensorflow.
46+
__file__ = tensorflow.__file__

0 commit comments

Comments
 (0)