Skip to content

Commit 491bbb6

Browse files
authored
Lite Classification Accuracy Fix (#798)
1 parent 95f44b3 commit 491bbb6

File tree

4 files changed

46 additions and 115 deletions

4 files changed

46 additions and 115 deletions

lite/tests/classification/test_accuracy.py

+7-90
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,12 @@ def test_accuracy_computation():
5353
)
5454

5555
# score threshold, label, count metric
56-
assert accuracy.shape == (2, 4)
56+
assert accuracy.shape == (2,)
5757

5858
# score >= 0.25
59-
assert accuracy[0][0] == 2 / 3
60-
assert accuracy[0][1] == 1.0
61-
assert accuracy[0][2] == 2 / 3
62-
assert accuracy[0][3] == 1.0
59+
assert accuracy[0] == 2 / 3
6360
# score >= 0.75
64-
assert accuracy[1][0] == 2 / 3
65-
assert accuracy[1][1] == 1.0
66-
assert accuracy[1][2] == 2 / 3
67-
assert accuracy[1][3] == 2 / 3
61+
assert accuracy[1] == 1 / 3
6862

6963

7064
def test_accuracy_basic(basic_classifications: list[Classification]):
@@ -87,20 +81,10 @@ def test_accuracy_basic(basic_classifications: list[Classification]):
8781
expected_metrics = [
8882
{
8983
"type": "Accuracy",
90-
"value": [2 / 3, 2 / 3],
84+
"value": [2 / 3, 1 / 3],
9185
"parameters": {
9286
"score_thresholds": [0.25, 0.75],
9387
"hardmax": True,
94-
"label": "0",
95-
},
96-
},
97-
{
98-
"type": "Accuracy",
99-
"value": [1.0, 2 / 3],
100-
"parameters": {
101-
"score_thresholds": [0.25, 0.75],
102-
"hardmax": True,
103-
"label": "3",
10488
},
10589
},
10690
]
@@ -124,29 +108,10 @@ def test_accuracy_with_animal_example(
124108
expected_metrics = [
125109
{
126110
"type": "Accuracy",
127-
"value": [2.0 / 3.0],
128-
"parameters": {
129-
"score_thresholds": [0.5],
130-
"hardmax": True,
131-
"label": "bird",
132-
},
133-
},
134-
{
135-
"type": "Accuracy",
136-
"value": [0.5],
111+
"value": [2.0 / 6.0],
137112
"parameters": {
138113
"score_thresholds": [0.5],
139114
"hardmax": True,
140-
"label": "dog",
141-
},
142-
},
143-
{
144-
"type": "Accuracy",
145-
"value": [2 / 3],
146-
"parameters": {
147-
"score_thresholds": [0.5],
148-
"hardmax": True,
149-
"label": "cat",
150115
},
151116
},
152117
]
@@ -170,38 +135,10 @@ def test_accuracy_color_example(
170135
expected_metrics = [
171136
{
172137
"type": "Accuracy",
173-
"value": [2 / 3],
174-
"parameters": {
175-
"score_thresholds": [0.5],
176-
"hardmax": True,
177-
"label": "white",
178-
},
179-
},
180-
{
181-
"type": "Accuracy",
182-
"value": [2 / 3],
138+
"value": [2 / 6],
183139
"parameters": {
184140
"score_thresholds": [0.5],
185141
"hardmax": True,
186-
"label": "red",
187-
},
188-
},
189-
{
190-
"type": "Accuracy",
191-
"value": [2 / 3],
192-
"parameters": {
193-
"score_thresholds": [0.5],
194-
"hardmax": True,
195-
"label": "blue",
196-
},
197-
},
198-
{
199-
"type": "Accuracy",
200-
"value": [5 / 6],
201-
"parameters": {
202-
"score_thresholds": [0.5],
203-
"hardmax": True,
204-
"label": "black",
205142
},
206143
},
207144
]
@@ -237,7 +174,6 @@ def test_accuracy_with_image_example(
237174
"parameters": {
238175
"score_thresholds": [0.0],
239176
"hardmax": True,
240-
"label": "v4",
241177
},
242178
},
243179
]
@@ -269,29 +205,10 @@ def test_accuracy_with_tabular_example(
269205
expected_metrics = [
270206
{
271207
"type": "Accuracy",
272-
"value": [0.7],
273-
"parameters": {
274-
"score_thresholds": [0.0],
275-
"hardmax": True,
276-
"label": "0",
277-
},
278-
},
279-
{
280-
"type": "Accuracy",
281-
"value": [0.5],
282-
"parameters": {
283-
"score_thresholds": [0.0],
284-
"hardmax": True,
285-
"label": "1",
286-
},
287-
},
288-
{
289-
"type": "Accuracy",
290-
"value": [0.8],
208+
"value": [5 / 10],
291209
"parameters": {
292210
"score_thresholds": [0.0],
293211
"hardmax": True,
294-
"label": "2",
295212
},
296213
},
297214
]

lite/valor_lite/classification/computation.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,9 @@ def compute_metrics(
182182
out=precision,
183183
)
184184

185-
accuracy = np.zeros_like(recall)
185+
accuracy = np.zeros(n_scores, dtype=np.float64)
186186
np.divide(
187-
(counts[:, :, 0] + counts[:, :, 3]),
187+
counts[:, :, 0].sum(axis=1),
188188
float(n_datums),
189189
out=accuracy,
190190
)

lite/valor_lite/classification/manager.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,14 @@ def compute_precision_recall(
367367
)
368368
]
369369

370+
metrics[MetricType.Accuracy] = [
371+
Accuracy(
372+
value=accuracy.astype(float).tolist(),
373+
score_thresholds=score_thresholds,
374+
hardmax=hardmax,
375+
)
376+
]
377+
370378
for label_idx, label in self.index_to_label.items():
371379

372380
kwargs = {
@@ -401,12 +409,6 @@ def compute_precision_recall(
401409
**kwargs,
402410
)
403411
)
404-
metrics[MetricType.Accuracy].append(
405-
Accuracy(
406-
value=accuracy[:, label_idx].astype(float).tolist(),
407-
**kwargs,
408-
)
409-
)
410412
metrics[MetricType.F1].append(
411413
F1(
412414
value=f1_score[:, label_idx].astype(float).tolist(),

lite/valor_lite/classification/metric.py

+29-17
Original file line numberDiff line numberDiff line change
@@ -158,24 +158,23 @@ class Recall(_ThresholdValue):
158158
pass
159159

160160

161-
class Accuracy(_ThresholdValue):
161+
class F1(_ThresholdValue):
162162
"""
163-
Accuracy metric for a specific class label.
163+
F1 score for a specific class label.
164164
165-
This class calculates the accuracy at various score thresholds for a binary
166-
classification task. Accuracy is defined as the ratio of the sum of true positives and
167-
true negatives over all predictions.
165+
This class calculates the F1 score at various score thresholds for a binary
166+
classification task.
168167
169168
Attributes
170169
----------
171170
value : list[float]
172-
Accuracy values computed at each score threshold.
171+
F1 scores computed at each score threshold.
173172
score_thresholds : list[float]
174-
Score thresholds at which the accuracy values are computed.
173+
Score thresholds at which the F1 scores are computed.
175174
hardmax : bool
176175
Indicates whether hardmax thresholding was used.
177176
label : str
178-
The class label for which the accuracy is computed.
177+
The class label for which the F1 score is computed.
179178
180179
Methods
181180
-------
@@ -188,23 +187,21 @@ class Accuracy(_ThresholdValue):
188187
pass
189188

190189

191-
class F1(_ThresholdValue):
190+
@dataclass
191+
class Accuracy:
192192
"""
193-
F1 score for a specific class label.
193+
Multiclass accuracy metric.
194194
195-
This class calculates the F1 score at various score thresholds for a binary
196-
classification task.
195+
This class calculates the accuracy at various score thresholds.
197196
198197
Attributes
199198
----------
200199
value : list[float]
201-
F1 scores computed at each score threshold.
200+
Accuracy values computed at each score threshold.
202201
score_thresholds : list[float]
203-
Score thresholds at which the F1 scores are computed.
202+
Score thresholds at which the accuracy values are computed.
204203
hardmax : bool
205204
Indicates whether hardmax thresholding was used.
206-
label : str
207-
The class label for which the F1 score is computed.
208205
209206
Methods
210207
-------
@@ -214,7 +211,22 @@ class F1(_ThresholdValue):
214211
Converts the instance to a dictionary representation.
215212
"""
216213

217-
pass
214+
value: list[float]
215+
score_thresholds: list[float]
216+
hardmax: bool
217+
218+
def to_metric(self) -> Metric:
219+
return Metric(
220+
type=type(self).__name__,
221+
value=self.value,
222+
parameters={
223+
"score_thresholds": self.score_thresholds,
224+
"hardmax": self.hardmax,
225+
},
226+
)
227+
228+
def to_dict(self) -> dict:
229+
return self.to_metric().to_dict()
218230

219231

220232
@dataclass

0 commit comments

Comments
 (0)