Skip to content

Commit

Permalink
Merge pull request #628 from pkorobov/fix-dot-recall
Browse files Browse the repository at this point in the history
Fix incorrect dot_factor usage
  • Loading branch information
erikbern committed Aug 20, 2023
2 parents 75429e5 + a70dac2 commit 2be37c9
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 24 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ We do this k times so that we get a forest of trees. k has to be tuned to your n

Hamming distance (contributed by `Martin Aumüller <https://github.com/maumueller>`__) packs the data into 64-bit integers under the hood and uses built-in bit count primitives so it could be quite fast. All splits are axis-aligned.

Dot Product distance (contributed by `Peter Sobot <https://github.com/psobot>`__) reduces the provided vectors from dot (or "inner-product") space to a more query-friendly cosine space using `a method by Bachrach et al., at Microsoft Research, published in 2014 <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf>`__.
Dot Product distance (contributed by `Peter Sobot <https://github.com/psobot>`__ and `Pavel Korobov <https://github.com/pkorobov>`__) reduces the provided vectors from dot (or "inner-product") space to a more query-friendly cosine space using `a method by Bachrach et al., at Microsoft Research, published in 2014 <https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/XboxInnerProduct.pdf>`__.



Expand Down
112 changes: 97 additions & 15 deletions src/annoylib.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,6 @@ inline float euclidean_distance<float>(const float* x, const float* y, int f) {

#endif


template<typename T>
inline T get_norm(T* v, int f) {
return sqrt(dot(v, v, f));
}

template<typename T, typename Random, typename Distance, typename Node>
inline void two_means(const vector<Node*>& nodes, int f, Random& random, bool cosine, Node* p, Node* q) {
Expand Down Expand Up @@ -391,18 +386,16 @@ inline void two_means(const vector<Node*>& nodes, int f, Random& random, bool co
size_t k = random.index(count);
T di = ic * Distance::distance(p, nodes[k], f),
dj = jc * Distance::distance(q, nodes[k], f);
T norm = cosine ? get_norm(nodes[k]->v, f) : 1;
T norm = cosine ? Distance::template get_norm<T, Node>(nodes[k], f) : 1;
if (!(norm > T(0))) {
continue;
}
if (di < dj) {
for (int z = 0; z < f; z++)
p->v[z] = (p->v[z] * ic + nodes[k]->v[z] / norm) / (ic + 1);
Distance::update_mean(p, nodes[k], norm, ic, f);
Distance::init_node(p, f);
ic++;
} else if (dj < di) {
for (int z = 0; z < f; z++)
q->v[z] = (q->v[z] * jc + nodes[k]->v[z] / norm) / (jc + 1);
Distance::update_mean(q, nodes[k], norm, jc, f);
Distance::init_node(q, f);
jc++;
}
Expand All @@ -417,6 +410,12 @@ struct Base {
// on the entire set of nodes passed into this index.
}

template<typename T, typename S, typename Node>
static inline void postprocess(void* nodes, size_t _s, const S node_count, const int f) {
// Override this in specific metric structs below if you need to do any post-processing
// on the entire set of nodes passed into this index.
}

template<typename Node>
static inline void zero_value(Node* dest) {
// Initialize any fields that require sane defaults within this node.
Expand All @@ -427,14 +426,25 @@ struct Base {
memcpy(dest->v, source->v, f * sizeof(T));
}

template<typename T, typename Node>
static inline T get_norm(Node* node, int f) {
return sqrt(dot(node->v, node->v, f));
}

template<typename T, typename Node>
static inline void normalize(Node* node, int f) {
T norm = get_norm(node->v, f);
T norm = Base::get_norm<T, Node>(node, f);
if (norm > 0) {
for (int z = 0; z < f; z++)
node->v[z] /= norm;
}
}

template<typename T, typename Node>
static inline void update_mean(Node* mean, Node* new_node, T norm, int c, int f) {
for (int z = 0; z < f; z++)
mean->v[z] = (mean->v[z] * c + new_node->v[z] / norm) / (c + 1);
}
};

struct Angular : Base {
Expand Down Expand Up @@ -486,6 +496,10 @@ struct Angular : Base {
return (bool)random.flip();
}
template<typename S, typename T, typename Random>
static inline bool side(const Node<S, T>* n, const Node<S, T>* y, int f, Random& random) {
return side(n, y->v, f, random);
}
template<typename S, typename T, typename Random>
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
Node<S, T>* p = (Node<S, T>*)alloca(s);
Node<S, T>* q = (Node<S, T>*)alloca(s);
Expand Down Expand Up @@ -525,20 +539,50 @@ struct DotProduct : Angular {
template<typename S, typename T>
struct Node {
/*
* This is an extension of the Angular node with an extra attribute for the scaled norm.
* This is an extension of the Angular node with extra attributes for the DotProduct metric.
* It has dot_factor which is needed to reduce the task to Angular distance metric (see the preprocess method)
* and also a built flag that helps to compute exact dot products when an index is already built.
*/
S n_descendants;
S children[2]; // Will possibly store more than 2
T dot_factor;
T norm;
bool built;
T v[ANNOYLIB_V_ARRAY_SIZE];
};

static const char* name() {
return "dot";
}

template<typename T, typename Node>
static inline T get_norm(Node* node, int f) {
return sqrt(dot(node->v, node->v, f) + node->dot_factor * node->dot_factor);
}

template<typename T, typename Node>
static inline void update_mean(Node* mean, Node* new_node, T norm, int c, int f) {
for (int z = 0; z < f; z++)
mean->v[z] = (mean->v[z] * c + new_node->v[z] / norm) / (c + 1);
mean->dot_factor = (mean->dot_factor * c + new_node->dot_factor / norm) / (c + 1);
}

template<typename S, typename T>
static inline T distance(const Node<S, T>* x, const Node<S, T>* y, int f) {
return -dot(x->v, y->v, f);
if (x->built || y->built) {
// When index is already built, we don't need angular distances to retrieve NNs
// Thus, we can return dot product scores itself
return -dot(x->v, y->v, f);
}

// Calculated by analogy with the angular case
T pp = x->norm ? x->norm : dot(x->v, x->v, f) + x->dot_factor * x->dot_factor;
T qq = y->norm ? y->norm : dot(y->v, y->v, f) + y->dot_factor * y->dot_factor;
T pq = dot(x->v, y->v, f) + x->dot_factor * y->dot_factor;
T ppqq = pp * qq;

if (ppqq > 0) return 2.0 - 2.0 * pq / sqrt(ppqq);
else return 2.0;
}

template<typename Node>
Expand All @@ -548,6 +592,8 @@ struct DotProduct : Angular {

template<typename S, typename T>
static inline void init_node(Node<S, T>* n, int f) {
n->built = false;
n->norm = dot(n->v, n->v, f) + n->dot_factor * n->dot_factor;
}

template<typename T, typename Node>
Expand Down Expand Up @@ -581,7 +627,21 @@ struct DotProduct : Angular {

template<typename S, typename T>
static inline T margin(const Node<S, T>* n, const T* y, int f) {
return dot(n->v, y, f) + (n->dot_factor * n->dot_factor);
return dot(n->v, y, f);
}

template<typename S, typename T>
static inline T margin(const Node<S, T>* n, const Node<S, T>* y, int f) {
return dot(n->v, y->v, f) + n->dot_factor * y->dot_factor;
}

template<typename S, typename T, typename Random>
static inline bool side(const Node<S, T>* n, const Node<S, T>* y, int f, Random& random) {
T dot = margin(n, y, f);
if (dot != 0)
return (dot > 0);
else
return (bool)random.flip();
}

template<typename S, typename T, typename Random>
Expand Down Expand Up @@ -609,6 +669,7 @@ struct DotProduct : Angular {
T d = dot(node->v, node->v, f);
T norm = d < 0 ? 0 : sqrt(d);
node->dot_factor = norm;
node->built = false;
}

// Step two: find the maximum norm
Expand All @@ -627,9 +688,19 @@ struct DotProduct : Angular {
T squared_norm_diff = pow(max_norm, static_cast<T>(2.0)) - pow(node_norm, static_cast<T>(2.0));
T dot_factor = squared_norm_diff < 0 ? 0 : sqrt(squared_norm_diff);

node->norm = pow(max_norm, static_cast<T>(2.0));
node->dot_factor = dot_factor;
}
}

template<typename T, typename S, typename Node>
static inline void postprocess(void* nodes, size_t _s, const S node_count, const int f) {
for (S i = 0; i < node_count; i++) {
Node* node = get_node_ptr<S, Node>(nodes, _s, i);
// When an index is built, we will remember it in index item nodes to compute distances differently
node->built = true;
}
}
};

struct Hamming : Base {
Expand Down Expand Up @@ -681,6 +752,10 @@ struct Hamming : Base {
return margin(n, y, f);
}
template<typename S, typename T, typename Random>
static inline bool side(const Node<S, T>* n, const Node<S, T>* y, int f, Random& random) {
return side(n, y->v, f, random);
}
template<typename S, typename T, typename Random>
static inline void create_split(const vector<Node<S, T>*>& nodes, int f, size_t s, Random& random, Node<S, T>* n) {
size_t cur_size = 0;
size_t i = 0;
Expand Down Expand Up @@ -748,6 +823,10 @@ struct Minkowski : Base {
else
return (bool)random.flip();
}
template<typename S, typename T, typename Random>
static inline bool side(const Node<S, T>* n, const Node<S, T>* y, int f, Random& random) {
return side(n, y->v, f, random);
}
template<typename T>
static inline T pq_distance(T distance, T margin, int child_nr) {
if (child_nr == 0)
Expand Down Expand Up @@ -991,6 +1070,9 @@ template<typename S, typename T, typename Distance, typename Random, class Threa
}
_nodes_size = _n_nodes;
}

D::template postprocess<T, S, Node>(_nodes, _s, _n_items, _f);

_built = true;
return true;
}
Expand Down Expand Up @@ -1310,7 +1392,7 @@ template<typename S, typename T, typename Distance, typename Random, class Threa
S j = indices[i];
Node* n = _get(j);
if (n) {
bool side = D::side(m, n->v, _f, _random);
bool side = D::side(m, n, _f, _random);
children_indices[side].push_back(j);
} else {
annoylib_showUpdate("No node for index %d?\n", j);
Expand Down
34 changes: 26 additions & 8 deletions test/accuracy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from urllib.request import urlretrieve # Python 3


def _get_index(dataset):
url = "http://vectors.erikbern.com/%s.hdf5" % dataset
def _get_index(dataset, custom_distance=None, custom_dim=None):
url = 'http://ann-benchmarks.com/%s.hdf5' % dataset
vectors_fn = os.path.join("test", dataset + ".hdf5")
index_fn = os.path.join("test", dataset + ".annoy")

Expand All @@ -37,29 +37,42 @@ def _get_index(dataset):

dataset_f = h5py.File(vectors_fn, "r")
distance = dataset_f.attrs["distance"]
if custom_distance is not None:
distance = custom_distance
f = dataset_f["train"].shape[1]
if custom_dim:
f = custom_dim
if custom_distance:
dataset = dataset.rsplit('-', 2)[0] + "-%d-%s" % (f, custom_distance)
index_fn = os.path.join('test', dataset + '.annoy')


annoy = AnnoyIndex(f, distance)

if not os.path.exists(index_fn):
print("adding items", distance, f)
for i, v in enumerate(dataset_f["train"]):
if len(v) > f:
v = v[:f]
annoy.add_item(i, v)

print("building index")
annoy.build(10)
annoy.save(index_fn)
else:
annoy.load(index_fn)
return annoy, dataset_f
return annoy, dataset_f, dataset


def _test_index(dataset, exp_accuracy):
annoy, dataset_f = _get_index(dataset)
def _test_index(dataset, exp_accuracy, custom_metric=None, custom_dim=None):
annoy, dataset_f, dataset = _get_index(dataset, custom_metric, custom_dim)

n, k = 0, 0

for i, v in enumerate(dataset_f["test"]):
js_fast = annoy.get_nns_by_vector(v, 10, 1000)
if custom_dim:
v = v[:custom_dim]
js_fast = annoy.get_nns_by_vector(v, 10, 10000)
js_real = dataset_f["neighbors"][i][:10]
assert len(js_fast) == 10
assert len(js_real) == 10
Expand All @@ -72,6 +85,7 @@ def _test_index(dataset, exp_accuracy):
"%50s accuracy: %5.2f%% (expected %5.2f%%)" % (dataset, accuracy, exp_accuracy)
)


assert accuracy > exp_accuracy - 1.0 # should be within 1%


Expand All @@ -83,5 +97,9 @@ def test_nytimes_16():
_test_index("nytimes-16-angular", 80.00)


def test_fashion_mnist():
_test_index("fashion-mnist-784-euclidean", 90.00)
def test_lastfm_dot():
_test_index('lastfm-64-dot', 60.00, 'dot', 64)


def test_lastfm_angular():
_test_index('lastfm-64-dot', 60.00, 'angular', 65)
2 changes: 2 additions & 0 deletions test/dot_index_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def test_dist():
i.add_item(0, [0, 1])
i.add_item(1, [1, 1])
i.add_item(2, [0, 0])
i.build(10)

assert i.get_distance(0, 1) == pytest.approx(1.0)
assert i.get_distance(1, 2) == pytest.approx(0.0)
Expand Down Expand Up @@ -161,3 +162,4 @@ def test_distance_consistency():
numpy.dot(i.get_item_vector(a), i.get_item_vector(b))
)
assert dist == pytest.approx(i.get_distance(a, b))

0 comments on commit 2be37c9

Please sign in to comment.