Skip to content

Commit

Permalink
Fix relevancy cache and normalize
Browse files Browse the repository at this point in the history
  • Loading branch information
vbhavank committed Aug 14, 2020
1 parent 59e9fa4 commit c719e45
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 28 deletions.
31 changes: 20 additions & 11 deletions python/smqtk/algorithms/nn_index/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@
# Requires FAISS bindings
try:
import faiss
import sklearn
except ImportError:
faiss = None
sklearn = None


class FaissNearestNeighborsIndex (NearestNeighborsIndex):
Expand All @@ -50,10 +52,18 @@ def gpu_supported():
else:
return False

@staticmethod
def normalize_vec(data, min_range=0, max_range=1):
    """
    Min-max scale each row of ``data`` into ``[min_range, max_range]``.

    Scaling is performed per sample (``axis=1``) and, where the input
    array allows it, in place (``copy=False``).

    NOTE(review): min-max scaling is not unit-length (L2) normalization.
    Callers that rely on "L2 distance over normalized vectors behaves
    like cosine distance" should confirm this scaling is what they want.

    :param data: 2D array-like of vectors, one vector per row.
    :param min_range: Lower bound of the output range.
    :param max_range: Upper bound of the output range.

    :return: The scaled array.
    """
    # Import the submodule explicitly: a bare ``import sklearn`` does not
    # guarantee that ``sklearn.preprocessing`` has been loaded.
    from sklearn.preprocessing import minmax_scale

    # Fix: the original passed an undefined name ``X`` instead of ``data``.
    return minmax_scale(
        data,
        feature_range=(min_range, max_range),
        axis=1,
        copy=False,
    )

@classmethod
def is_usable(cls):
    """
    Report whether this implementation's optional dependencies loaded.

    The guarded import at the top of the module sets both ``faiss`` and
    ``sklearn`` to ``None`` when the import fails, so both are checked.

    :return: True only if both ``faiss`` and ``sklearn`` are available.
    :rtype: bool
    """
    # A single boolean expression replaces the early
    # ``return faiss is not None`` that made the sklearn check unreachable.
    return faiss is not None and sklearn is not None

@classmethod
def get_default_config(cls):
Expand Down Expand Up @@ -183,6 +193,10 @@ def __init__(self, descriptor_set, idx2uid_kvs, uid2idx_kvs,
existing index. False by default.
:type read_only: bool
:param distance_m: Key for selecting the metric used during indexing
and retrieval. 'cosine' and 'euclidean' are currently supported.
:type distance_m: str
:param factory_string: String to pass to FAISS' `index_factory`;
see the documentation [1] on this feature for more details.
:type factory_string: str | unicode
Expand Down Expand Up @@ -436,11 +450,9 @@ def _build_index(self, descriptors):

faiss_index = self._index_factory_wrapper(d, self.factory_string)
# noinspection PyArgumentList
if self._distance_metric:
data = (
data / np.linalg.norm(
data, axis=1, keepdims=True)
)
if self._distance_metric == 'cosine':
# Normalizing vector before using L2 will result in cosine distance.
data = normalize_vec(data)
faiss_index.train(data)
# TODO(john.moeller): This will raise an exception on flat indexes.
# There's a solution which involves wrapping the index in an
Expand Down Expand Up @@ -651,11 +663,8 @@ def _nn(self, d, n=1):
"""
q = d.vector()[np.newaxis, :].astype(np.float32)
if self._distance_metric:
q = (
q / np.linalg.norm(
q, axis=1, keepdims=True)
)
if self._distance_metric == 'cosine':
q = normalize_vec(q)
self._log.debug("Received query for %d nearest neighbors", n)

with self._model_lock:
Expand Down
18 changes: 1 addition & 17 deletions python/smqtk/algorithms/relevancy_index/logistic_reg.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@ class LogisticRegRelevancyIndex (RelevancyIndex):
to implement IQR ranking.
"""

# Dictionary of parameter/value pairs that will be passed to libSVM during
# the model trail phase. Parameters that are flags, i.e. have no values,
# should be given an empty string ('') value.
LR_TRAIN_PARAMS = {
"penalty": "l2",
"dual": True,
Expand All @@ -47,16 +44,13 @@ def is_usable(cls):
"""
return LogisticRegression and sklearn

def __init__(self, descr_cache_filepath=None, autoneg_select_ratio=1,
def __init__(self, autoneg_select_ratio=1,
multiprocess_fetch=False, cores=None):
"""
Initialize a new or existing index.
TODO ::
- input optional known background descriptors, i.e. descriptors for
things that would otherwise always be considered a negative example.
:param descr_cache_filepath: Optional path to store/load descriptors
we index.
:type descr_cache_filepath: None | str
:param autoneg_select_ratio: Number of maximally distant descriptors to
select from our descriptor cache for each positive example provided
when no negative examples are provided for ranking.
Expand All @@ -73,7 +67,6 @@ def __init__(self, descr_cache_filepath=None, autoneg_select_ratio=1,
"""
super(LogisticRegRelevancyIndex, self).__init__()

self.descr_cache_fp = descr_cache_filepath
self.autoneg_select_ratio = int(autoneg_select_ratio)
self.multiprocess_fetch = multiprocess_fetch
self.cores = cores
Expand All @@ -87,20 +80,12 @@ def __init__(self, descr_cache_filepath=None, autoneg_select_ratio=1,
# subsequently in the distance kernel
self._descr2index = {}

if self.descr_cache_fp and osp.exists(self.descr_cache_fp):
with open(self.descr_cache_fp, 'rb') as f:
descriptors = pickle.load(f)
self.descr_cache_fp = None
self.build_index(descriptors)
self.descr_cache_fp = descr_cache_filepath

@classmethod
def _gen_lr_parameter_string(cls):
    """
    Return the class-level logistic regression training parameters.

    Despite the method name, the value returned is the ``LR_TRAIN_PARAMS``
    mapping itself (not a copy and not a string).
    """
    train_params = cls.LR_TRAIN_PARAMS
    return train_params

def get_config(self):
return {
"descr_cache_filepath": self.descr_cache_fp,
'autoneg_select_ratio': self.autoneg_select_ratio,
'multiprocess_fetch': self.multiprocess_fetch,
'cores': self.cores,
Expand Down Expand Up @@ -149,7 +134,6 @@ def get_vector(d_elem):
self._descr2index[tuple(v)] = i
self._descr_matrix = numpy.array(self._descr_matrix)


def rank(self, pos, neg):
"""
Rank the currently indexed elements given ``pos`` positive and ``neg``
Expand Down

0 comments on commit c719e45

Please sign in to comment.