Skip to content

Commit 54cecb6

Browse files
committed
Improve: Out-of-bounds checks
1 parent 1b40f13 commit 54cecb6

File tree

1 file changed

+21
-12
lines changed

1 file changed

+21
-12
lines changed

python/usearch/index.py

+21-12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import math
99
from typing import Optional, Union, NamedTuple, List, Iterable
10+
from dataclasses import dataclass
1011

1112
import numpy as np
1213
from tqdm import tqdm
@@ -101,24 +102,28 @@ def _normalize_metric(metric):
101102

102103
return metric
103104

104-
105-
class Match(NamedTuple):
105+
@dataclass
106+
class Match:
106107
label: int
107108
distance: float
108109

109110

110-
class Matches(NamedTuple):
111+
@dataclass
112+
class Matches:
111113
labels: np.ndarray
112114
distances: np.ndarray
113115

114116
def __len__(self) -> int:
115117
return len(self.labels)
116118

117119
def __getitem__(self, index: int) -> Match:
118-
return Match(
119-
label=self.labels[index],
120-
distance=self.distances[index],
121-
)
120+
if isinstance(index, int) and index < len(self):
121+
return Match(
122+
label=self.labels[index],
123+
distance=self.distances[index],
124+
)
125+
else:
126+
raise IndexError(f"`index` must be an integer under {len(self)}")
122127

123128
def to_list(self) -> List[tuple]:
124129
return [(int(l), float(d)) for l, d in zip(self.labels, self.distances)]
@@ -127,7 +132,8 @@ def __repr__(self) -> str:
127132
return f"usearch.Matches({len(self)})"
128133

129134

130-
class BatchMatches(NamedTuple):
135+
@dataclass
136+
class BatchMatches:
131137
labels: np.ndarray
132138
distances: np.ndarray
133139
counts: np.ndarray
@@ -136,10 +142,13 @@ def __len__(self) -> int:
136142
return len(self.counts)
137143

138144
def __getitem__(self, index: int) -> Matches:
139-
return Matches(
140-
labels=self.labels[index, : self.counts[index]],
141-
distances=self.distances[index, : self.counts[index]],
142-
)
145+
if isinstance(index, int) and index < len(self):
146+
return Matches(
147+
labels=self.labels[index, : self.counts[index]],
148+
distances=self.distances[index, : self.counts[index]],
149+
)
150+
else:
151+
raise IndexError(f"`index` must be an integer under {len(self)}")
143152

144153
def to_list(self) -> List[List[tuple]]:
145154
lists = [self.__getitem__(row) for row in range(self.__len__())]

0 commit comments

Comments
 (0)