diff --git a/hands_on/pyanno_voting/pyanno/tests/test_voting.py b/hands_on/pyanno_voting/pyanno/tests/test_voting.py index f21a10c..c6092b0 100644 --- a/hands_on/pyanno_voting/pyanno/tests/test_voting.py +++ b/hands_on/pyanno_voting/pyanno/tests/test_voting.py @@ -5,6 +5,7 @@ def test_labels_count(): + #Given annotations = [ [1, 2, MV, MV], [MV, MV, 3, 3], @@ -13,7 +14,11 @@ def test_labels_count(): ] nclasses = 5 expected = [0, 3, 1, 3, 0] + + #When result = voting.labels_count(annotations, nclasses) + + #Then assert result == expected @@ -41,3 +46,20 @@ def test_majority_vote_empty_item(): expected = [1, MV, 2] result = voting.majority_vote(annotations) assert result == expected + + +def test_label_frequency(): + #Given + annotations = np.array( + [[1, 2, 3], + [-1, -1, -1], + [1, 2, 2]] + ) + n_classes = 4 + expected = np.array([2/6, 3/6, 1/6, 0/6]) + + #When + result = voting.labels_frequency(annotations, n_classes) + + #Then + assert np.allclose(result, expected, rtol = 1e-6) diff --git a/hands_on/pyanno_voting/pyanno/voting.py b/hands_on/pyanno_voting/pyanno/voting.py index d5b5747..af1044a 100644 --- a/hands_on/pyanno_voting/pyanno/voting.py +++ b/hands_on/pyanno_voting/pyanno/voting.py @@ -100,3 +100,11 @@ def labels_frequency(annotations, nclasses): freq[k] is the frequency of elements of class k in `annotations`, i.e. their count over the number of total of observed (non-missing) elements """ + + result = np.array([0]*nclasses) + for val in np.array(annotations).flatten(): + if val > 0: + result[val-1] += 1 + + + return result/sum(result)