diff --git a/src/metric.jl b/src/metric.jl index 7916d45b639c..7e76b969d0a0 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -36,6 +36,19 @@ set. =# abstract AbstractEvalMetric +# Generic update! version +function update!{T <: AbstractEvalMetric}(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray}) + if length(labels) != length(preds) + Base.warn_once( + "The number of labels ($(length(labels))) does not correspond to the\ + number of outputs ($(length(preds))). The calculated metric might not be accuracte.") + end + for (label, pred) in zip(labels, preds) + _update_single_output(metric, label, pred) + end +end + + #=doc .. class:: Accuracy @@ -85,13 +98,6 @@ function _update_single_output(metric :: Accuracy, label :: NDArray, pred :: NDA end end -function update!(metric :: Accuracy, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - import Base: get function get(metric :: Accuracy) return [(:accuracy, metric.acc_sum / metric.n_sample)] @@ -129,13 +135,6 @@ function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray) end end -function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - function get(metric :: MSE) return [(:MSE, metric.mse_sum / metric.n_sample)] end @@ -193,13 +192,6 @@ function _update_single_output(metric :: ACE, label :: NDArray, pred :: NDArray) end end -function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - #=doc .. class:: MultiACE @@ -251,9 +243,3 @@ function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDA end end -function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end