7
7
import os
8
8
import math
9
9
from typing import Optional , Union , NamedTuple , List , Iterable
10
+ from dataclasses import dataclass
10
11
11
12
import numpy as np
12
13
from tqdm import tqdm
@@ -101,24 +102,28 @@ def _normalize_metric(metric):
101
102
102
103
return metric
103
104
104
-
105
- class Match ( NamedTuple ) :
105
+ @ dataclass
106
+ class Match :
106
107
label : int
107
108
distance : float
108
109
109
110
110
- class Matches (NamedTuple ):
111
+ @dataclass
112
+ class Matches :
111
113
labels : np .ndarray
112
114
distances : np .ndarray
113
115
114
116
def __len__ (self ) -> int :
115
117
return len (self .labels )
116
118
117
119
def __getitem__ (self , index : int ) -> Match :
118
- return Match (
119
- label = self .labels [index ],
120
- distance = self .distances [index ],
121
- )
120
+ if isinstance (index , int ) and index < len (self ):
121
+ return Match (
122
+ label = self .labels [index ],
123
+ distance = self .distances [index ],
124
+ )
125
+ else :
126
+ raise IndexError (f"`index` must be an integer under { len (self )} " )
122
127
123
128
def to_list (self ) -> List [tuple ]:
124
129
return [(int (l ), float (d )) for l , d in zip (self .labels , self .distances )]
@@ -127,7 +132,8 @@ def __repr__(self) -> str:
127
132
return f"usearch.Matches({ len (self )} )"
128
133
129
134
130
- class BatchMatches (NamedTuple ):
135
+ @dataclass
136
+ class BatchMatches :
131
137
labels : np .ndarray
132
138
distances : np .ndarray
133
139
counts : np .ndarray
@@ -136,10 +142,13 @@ def __len__(self) -> int:
136
142
return len (self .counts )
137
143
138
144
def __getitem__ (self , index : int ) -> Matches :
139
- return Matches (
140
- labels = self .labels [index , : self .counts [index ]],
141
- distances = self .distances [index , : self .counts [index ]],
142
- )
145
+ if isinstance (index , int ) and index < len (self ):
146
+ return Matches (
147
+ labels = self .labels [index , : self .counts [index ]],
148
+ distances = self .distances [index , : self .counts [index ]],
149
+ )
150
+ else :
151
+ raise IndexError (f"`index` must be an integer under { len (self )} " )
143
152
144
153
def to_list (self ) -> List [List [tuple ]]:
145
154
lists = [self .__getitem__ (row ) for row in range (self .__len__ ())]
0 commit comments