Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Making Features as a singleton for improved caching
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Aug 9, 2019
1 parent bfd3bb8 commit 914667d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
22 changes: 14 additions & 8 deletions python/mxnet/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,19 @@ class Features(collections.OrderedDict):
"""
OrderedDict of name to Feature
"""
def __init__(self):
super(Features, self).__init__([(f.name, f) for f in feature_list()])
instance = None
def __new__(cls):
if cls.instance is None:
cls.instance = super().__new__(cls)
return cls.instance

def __repr__(self):
return str(list(self.values()))
def __init__(cls):
super(Features, cls).__init__([(f.name, f) for f in feature_list()])

def __repr__(cls):
return str(list(cls.values()))

def is_enabled(self, feature_name):
def is_enabled(cls, feature_name):
"""
Check for a particular feature by name
Expand All @@ -94,7 +100,7 @@ def is_enabled(self, feature_name):
True if it's enabled, False if it's disabled, RuntimeError if the feature is not known
"""
feature_name = feature_name.upper()
if feature_name not in self:
if feature_name not in cls:
raise RuntimeError("Feature '{}' is unknown, known features are: {}".format(
feature_name, list(self.keys())))
return self[feature_name].enabled
feature_name, list(cls.keys())))
return cls[feature_name].enabled
9 changes: 9 additions & 0 deletions tests/python/unittest/test_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,20 @@
from mxnet.base import MXNetError
from nose.tools import *


def test_features():
features = Features()
print(features)
ok_('CUDA' in features)
ok_(len(features) >= 30)


def test_is_singleton():
x = Features()
y = Features()
assert x is y


def test_is_enabled():
features = Features()
for f in features:
Expand All @@ -35,6 +43,7 @@ def test_is_enabled():
else:
ok_(not features.is_enabled(f))


@raises(RuntimeError)
def test_is_enabled_not_existing():
features = Features()
Expand Down

0 comments on commit 914667d

Please sign in to comment.