Skip to content
This repository has been archived by the owner on Jul 10, 2021. It is now read-only.

Commit

Permalink
Another check to test for Theano import warning, fix for other broken…
Browse files Browse the repository at this point in the history
… tests by reimporting theano.
  • Loading branch information
alexjc committed May 7, 2015
1 parent e78b46b commit 7085c25
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ pip-delete-this-directory.txt
htmlcov/
.tox/
.coverage
.coveragerc
.coverage.*
.cache
.settings
nosetests.xml
coverage.xml
*,cover
Expand Down
4 changes: 2 additions & 2 deletions sknn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __init__(self):
def configure(self, flags):
if self.configured is True:
return

self.configured = True

if 'theano' in sys.modules:
self.log.warning('Theano was already imported and cannot be reconfigured.')
return
Expand All @@ -30,7 +31,6 @@ def configure(self, flags):
import theano
cuda.setLevel(logging.WARNING)

self.configured = True
try:
import theano.sandbox.cuda as cd
self.log.info('Using device gpu%i: %s', cd.active_device_number(), cd.active_device_name())
Expand Down
20 changes: 19 additions & 1 deletion sknn/tests/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import unittest
from nose.tools import (assert_in, assert_equal)

import io
import os
import sys
import logging

import sknn

Expand All @@ -12,13 +14,29 @@ class TestBackendPseudoModule(unittest.TestCase):
def setUp(self):
if 'THEANO_FLAGS' in os.environ:
del os.environ['THEANO_FLAGS']

self.removed = {}
for name in sys.modules.keys():
if name.startswith('theano'):
self.removed[name] = sys.modules[name]
del sys.modules[name]
sys.modules['sknn.backend'].configured = False

self.buf = io.StringIO()
self.hnd = logging.StreamHandler(self.buf)
logging.getLogger('sknn').addHandler(self.hnd)
logging.getLogger().setLevel(logging.WARNING)

def tearDown(self):
for name, module in self.removed.items():
sys.modules[name] = module
logging.getLogger('sknn').removeHandler(self.hnd)

def test_TheanoWarning(self):
pass
import theano
from sknn.backend import cpu
assert_equal('Theano was already imported and cannot be reconfigured.\n',
self.buf.getvalue())

def _check(self, flags):
assert_in('THEANO_FLAGS', os.environ)
Expand Down

0 comments on commit 7085c25

Please sign in to comment.