Skip to content

Commit

Permalink
Deduplicate and generalise metric update!
Browse files Browse the repository at this point in the history
MXNet allows for the design of networks that use the same label
for multiple outputs. Instead of failing for these kinds of
networks, warn the user and try to proceed.
  • Loading branch information
vchuravy committed May 11, 2016
1 parent baa9c2a commit 64a63f7
Showing 1 changed file with 13 additions and 27 deletions.
40 changes: 13 additions & 27 deletions src/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ set.
=#
abstract AbstractEvalMetric

# Generic update! version
function update!{T <: AbstractEvalMetric}(metric :: T, labels :: Vector{NDArray}, preds :: Vector{NDArray})
if length(labels) != length(preds)
Base.warn_once(
"The number of labels ($(length(labels))) does not correspond to the\
number of outputs ($(length(preds))). The calculated metric might not be accuracte.")
end
for (label, pred) in zip(labels, preds)
_update_single_output(metric, label, pred)
end
end


#=doc
.. class:: Accuracy
Expand Down Expand Up @@ -85,13 +98,6 @@ function _update_single_output(metric :: Accuracy, label :: NDArray, pred :: NDA
end
end

function update!(metric :: Accuracy, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

import Base: get
function get(metric :: Accuracy)
return [(:accuracy, metric.acc_sum / metric.n_sample)]
Expand Down Expand Up @@ -129,13 +135,6 @@ function _update_single_output(metric :: MSE, label :: NDArray, pred :: NDArray)
end
end

function update!(metric :: MSE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

function get(metric :: MSE)
return [(:MSE, metric.mse_sum / metric.n_sample)]
end
Expand Down Expand Up @@ -193,13 +192,6 @@ function _update_single_output(metric :: ACE, label :: NDArray, pred :: NDArray)
end
end

function update!(metric :: ACE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

#=doc
.. class:: MultiACE
Expand Down Expand Up @@ -251,9 +243,3 @@ function _update_single_output(metric :: MultiACE, label :: NDArray, pred :: NDA
end
end

function update!(metric :: MultiACE, labels :: Vector{NDArray}, preds :: Vector{NDArray})
@assert length(labels) == length(preds)
for i = 1:length(labels)
_update_single_output(metric, labels[i], preds[i])
end
end

0 comments on commit 64a63f7

Please sign in to comment.