From 64a63f7797b946e334e38fbb08bd7ddd79634c12 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Wed, 11 May 2016 14:15:54 +0900 Subject: [PATCH] Deduplicate and generalise metric update! MXNet allows for the design of networks that use the same label for multiple outputs. Instead of failing for these kinds of networks, warn the user and try to proceed. --- src/metric.jl | 40 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/src/metric.jl b/src/metric.jl index 7916d45b639c..7e76b969d0a0 100644 --- a/src/metric.jl +++ b/src/metric.jl @@ -36,6 +36,19 @@ set. =# abstract AbstractEvalMetric +# Generic update! version +function update!{T <: AbstractEvalMetric}(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray}) + if length(labels) != length(preds) + Base.warn_once( + "The number of labels ($(length(labels))) does not correspond to the\ + number of outputs ($(length(preds))). The calculated metric might not be accuracte.") + end + for (label, pred) in zip(labels, preds) + _update_single_output(metric, label, pred) + end +end + + #=doc .. class:: Accuracy @@ -85,13 +98,6 @@ function _update_single_output(metric :: Accuracy, label :: NDArray, pred :: NDA end end -function update!(metric :: Accuracy, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - import Base: get function get(metric :: Accuracy) return [(:accuracy, metric.acc_sum / metric.n_sample)] @@ -129,13 +135,6 @@ function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray) end end -function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - function get(metric :: MSE) return [(:MSE, metric.mse_sum / metric.n_sample)] end @@ -193,13 +192,6 @@ function _update_single_output(metric :: ACE, label :: NDArray, pred :: NDArray) end end -function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end - #=doc .. class:: MultiACE @@ -251,9 +243,3 @@ function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDA end end -function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray}) - @assert length(labels) == length(preds) - for i = 1:length(labels) - _update_single_output(metric, labels[i], preds[i]) - end -end