Skip to content

Commit

Permalink
Merge pull request apache#94 from oist/vc/generic_metric_update
Browse files Browse the repository at this point in the history
Deduplicate and generalise metric update!
  • Loading branch information
pluskid committed May 13, 2016
2 parents baa9c2a + 64a63f7 commit 0ea4369
Showing 1 changed file with 13 additions and 27 deletions.
40 changes: 13 additions & 27 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ set.
=#
abstract AbstractEvalMetric

# Generic update! version
function update!{T <: AbstractEvalMetric}(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray})
if length(labels) != length(preds)
Base.warn_once(
"The number of labels ($(length(labels))) does not correspond to the\
number of outputs ($(length(preds))). The calculated metric might not be accuracte.")
end
for (label, pred) in zip(labels, preds)
_update_single_output(metric, label, pred)
end
end


#=doc
.. class:: Accuracy
Expand Down Expand Up @@ -85,13 +98,6 @@ function _update_single_output(metric :: Accuracy, label :: NDArray, pred :: NDA
end
end

function update!(metric :: Accuracy, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

import Base: get
function get(metric :: Accuracy)
return [(:accuracy, metric.acc_sum / metric.n_sample)]
Expand Down Expand Up @@ -129,13 +135,6 @@ function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray)
end
end

function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

function get(metric :: MSE)
return [(:MSE, metric.mse_sum / metric.n_sample)]
end
Expand Down Expand Up @@ -193,13 +192,6 @@ function _update_single_output(metric :: ACE, label :: NDArray, pred :: NDArray)
end
end

function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

#=doc
.. class:: MultiACE
Expand Down Expand Up @@ -251,9 +243,3 @@ function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDA
end
end

function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

0 comments on commit 0ea4369

Please sign in to comment.