@@ -859,24 +859,50 @@ cdef class MultiIndexHashTable(HashTable):
859859 sizeof(size_t) + # vals
860860 sizeof(uint32_t)) # flags
861861
862+ def _check_for_collisions(self, int64_t[:] locs, object mi):
863+ # validate that the locs map to the actual values
864+ # provided in the mi
865+ # we can only check if we *don't* have any missing values
866+ # :<
867+ cdef:
868+ ndarray[int64_t] alocs
869+
870+ alocs = np.asarray(locs)
871+ if (alocs!=-1).all():
872+
873+ result = self.mi.take(locs)
874+ if not result.equals(mi):
875+ raise ValueError("hash collision alert")
876+
862877 def __contains__(self, object key):
863878 cdef:
864879 khiter_t k
865880 uint64_t value
866881
867882 value = self.mi._hashed_indexing_key(key)
868883 k = kh_get_uint64(self.table, value)
869- return k != self.table.n_buckets
884+ if k != self.table.n_buckets:
885+ loc = self.table.vals[k]
886+ locs = np.array([loc], dtype=np.int64)
887+ self._check_for_collisions(locs, key)
888+ return True
889+
890+ return False
870891
871892 cpdef get_item(self, object key):
872893 cdef:
873894 khiter_t k
874895 uint64_t value
896+ int64_t[:] locs
897+ Py_ssize_t loc
875898
876899 value = self.mi._hashed_indexing_key(key)
877900 k = kh_get_uint64(self.table, value)
878901 if k != self.table.n_buckets:
879- return self.table.vals[k]
902+ loc = self.table.vals[k]
903+ locs = np.array([loc], dtype=np.int64)
904+ self._check_for_collisions(locs, key)
905+ return loc
880906 else:
881907 raise KeyError(key)
882908
@@ -927,6 +953,7 @@ cdef class MultiIndexHashTable(HashTable):
927953 else:
928954 locs[i] = -1
929955
956+ self._check_for_collisions(locs, mi)
930957 return np.asarray(locs)
931958
932959 def unique(self, object mi):
0 commit comments