diff --git a/.gitignore b/.gitignore index 482942b..66cc806 100644 --- a/.gitignore +++ b/.gitignore @@ -43,8 +43,10 @@ pip-delete-this-directory.txt htmlcov/ .tox/ .coverage +.coveragerc .coverage.* .cache +.settings nosetests.xml coverage.xml *,cover diff --git a/sknn/__init__.py b/sknn/__init__.py index 2b3ae56..a31ad2d 100644 --- a/sknn/__init__.py +++ b/sknn/__init__.py @@ -19,7 +19,8 @@ def __init__(self): def configure(self, flags): if self.configured is True: return - + self.configured = True + if 'theano' in sys.modules: self.log.warning('Theano was already imported and cannot be reconfigured.') return @@ -30,7 +31,6 @@ def configure(self, flags): import theano cuda.setLevel(logging.WARNING) - self.configured = True try: import theano.sandbox.cuda as cd self.log.info('Using device gpu%i: %s', cd.active_device_number(), cd.active_device_name()) diff --git a/sknn/tests/test_backend.py b/sknn/tests/test_backend.py index 1f16304..6bbe781 100644 --- a/sknn/tests/test_backend.py +++ b/sknn/tests/test_backend.py @@ -1,8 +1,10 @@ import unittest from nose.tools import (assert_in, assert_equal) +import io import os import sys +import logging import sknn @@ -12,13 +14,29 @@ class TestBackendPseudoModule(unittest.TestCase): def setUp(self): if 'THEANO_FLAGS' in os.environ: del os.environ['THEANO_FLAGS'] + + self.removed = {} for name in sys.modules.keys(): if name.startswith('theano'): + self.removed[name] = sys.modules[name] del sys.modules[name] sys.modules['sknn.backend'].configured = False + self.buf = io.StringIO() + self.hnd = logging.StreamHandler(self.buf) + logging.getLogger('sknn').addHandler(self.hnd) + logging.getLogger().setLevel(logging.WARNING) + + def tearDown(self): + for name, module in self.removed.items(): + sys.modules[name] = module + logging.getLogger('sknn').removeHandler(self.hnd) + def test_TheanoWarning(self): - pass + import theano + from sknn.backend import cpu + assert_equal('Theano was already imported and cannot be reconfigured.\n', + self.buf.getvalue()) def _check(self, flags): assert_in('THEANO_FLAGS', os.environ)