Skip to content

Commit

Permalink
Fix confusmat error described in issue JuliaStats#35
Browse files Browse the repository at this point in the history
  • Loading branch information
asbisen committed Sep 4, 2018
1 parent 6221277 commit e245f24
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
8 changes: 4 additions & 4 deletions docs/source/perfeval.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ Classification Performance

Compute error rate of predictions given by ``pred`` w.r.t. the ground truths given in ``gt``.

.. function:: confusmat(k, gt, pred)
.. function:: confusmat(gt, pred)

Compute the confusion matrix of the predictions given by ``pred`` w.r.t. the ground truths given in ``gt``.
Here, ``k`` is the number of classes.

It returns an integer matrix ``R`` of size ``(k, k)``, such that ``R(i, j) == countnz((gt .== i) & (pred .== j))``.
It returns an integer matrix ``R`` of size ``(k, k)`` where k is the number of classes in ``gt``,
such that ``R(i, j) == countnz((gt .== i) & (pred .== j))``.

**Examples:**

Expand All @@ -29,7 +29,7 @@ Classification Performance
julia> pred = [1, 1, 2, 2, 2, 3, 3, 3];
julia> C = confusmat(3, gt, pred) # compute confusion matrix
julia> C = confusmat(gt, pred) # compute confusion matrix
3x3 Array{Int64,2}:
2 1 0
0 2 1
Expand Down
12 changes: 8 additions & 4 deletions src/perfeval.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,18 @@ correctrate(gt::IntegerVector, r::IntegerVector) = counteq(gt, r) / length(gt)
errorrate(gt::IntegerVector, r::IntegerVector) = countne(gt, r) / length(gt)

## confusion matrix

function confusmat(k::Integer, gts::IntegerVector, preds::IntegerVector)
function confusmat(gts::IntegerVector, preds::IntegerVector)
n = length(gts)
length(preds) == n || throw(DimensionMismatch("Inconsistent lengths."))

gtslbl = sort(unique(gts))
k = length(gtslbl)

lookup = Dict(reverse.(enumerate(gtslbl)|> collect))
R = zeros(Int, k, k)
for i = 1:n
@inbounds g = gts[i]
@inbounds p = preds[i]
@inbounds g = lookup[gts[i]]
@inbounds p = lookup[preds[i]]
R[g, p] += 1
end
return R
Expand Down

0 comments on commit e245f24

Please sign in to comment.