diff --git a/docs/source/perfeval.rst b/docs/source/perfeval.rst index 0b8e6ef..f03c503 100644 --- a/docs/source/perfeval.rst +++ b/docs/source/perfeval.rst @@ -14,12 +14,12 @@ Classification Performance Compute error rate of predictions given by ``pred`` w.r.t. the ground truths given in ``gt``. -.. function:: confusmat(k, gt, pred) +.. function:: confusmat(gt, pred) Compute the confusion matrix of the predictions given by ``pred`` w.r.t. the ground truths given in ``gt``. - Here, ``k`` is the number of classes. - It returns an integer matrix ``R`` of size ``(k, k)``, such that ``R(i, j) == countnz((gt .== i) & (pred .== j))``. + It returns an integer matrix ``R`` of size ``(k, k)`` where k is the number of classes in ``gt``, + such that ``R(i, j) == countnz((gt .== i) & (pred .== j))``. **Examples:** @@ -29,7 +29,7 @@ Classification Performance julia> pred = [1, 1, 2, 2, 2, 3, 3, 3]; - julia> C = confusmat(3, gt, pred) # compute confusion matrix + julia> C = confusmat(gt, pred) # compute confusion matrix 3x3 Array{Int64,2}: 2 1 0 0 2 1 diff --git a/src/perfeval.jl b/src/perfeval.jl index 4f521e2..b67491d 100644 --- a/src/perfeval.jl +++ b/src/perfeval.jl @@ -6,14 +6,18 @@ correctrate(gt::IntegerVector, r::IntegerVector) = counteq(gt, r) / length(gt) errorrate(gt::IntegerVector, r::IntegerVector) = countne(gt, r) / length(gt) ## confusion matrix - -function confusmat(k::Integer, gts::IntegerVector, preds::IntegerVector) +function confusmat(gts::IntegerVector, preds::IntegerVector) n = length(gts) length(preds) == n || throw(DimensionMismatch("Inconsistent lengths.")) + + gtslbl = sort(unique(gts)) + k = length(gtslbl) + + lookup = Dict(reverse.(enumerate(gtslbl)|> collect)) R = zeros(Int, k, k) for i = 1:n - @inbounds g = gts[i] - @inbounds p = preds[i] + @inbounds g = lookup[gts[i]] + @inbounds p = lookup[preds[i]] R[g, p] += 1 end return R