diff --git a/hands_on/pyanno_voting/pyanno/tests/test_voting.py b/hands_on/pyanno_voting/pyanno/tests/test_voting.py index f21a10c..2cbf4de 100644 --- a/hands_on/pyanno_voting/pyanno/tests/test_voting.py +++ b/hands_on/pyanno_voting/pyanno/tests/test_voting.py @@ -41,3 +41,10 @@ def test_majority_vote_empty_item(): expected = [1, MV, 2] result = voting.majority_vote(annotations) assert result == expected + + +def test_labels_frequency(): + result = voting.labels_frequency([[1, 1, 2], [MV, 1, 2]], 4) + expected = np.array([ 0. , 0.6, 0.4, 0. ]) + np.testing.assert_array_equal(result, expected) + diff --git a/hands_on/pyanno_voting/pyanno/voting.py b/hands_on/pyanno_voting/pyanno/voting.py index d5b5747..a461e08 100644 --- a/hands_on/pyanno_voting/pyanno/voting.py +++ b/hands_on/pyanno_voting/pyanno/voting.py @@ -100,3 +100,21 @@ def labels_frequency(annotations, nclasses): freq[k] is the frequency of elements of class k in `annotations`, i.e. their count over the number of total of observed (non-missing) elements """ + + annotations = np.asarray(annotations) + nitems = annotations.shape[0] + valid = annotations != MISSING_VALUE + + annotations_without_mv = annotations[valid] + + unique, counts = np.unique(annotations_without_mv, return_counts=True) + + output = np.zeros(nclasses) + j = 0 + for i in range(nclasses): + if i in unique: + output[i] = counts[j] / sum(counts) + j += 1 + + return(output) +