diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj index 1946103a4a2d..f1fe2d18bd35 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/eval_metric.clj @@ -18,7 +18,7 @@ (ns org.apache.clojure-mxnet.eval-metric (:refer-clojure :exclude [get update]) (:require [org.apache.clojure-mxnet.util :as util]) - (:import (org.apache.mxnet Accuracy TopKAccuracy F1 Perplexity MAE MSE RMSE CustomMetric))) + (:import (org.apache.mxnet Accuracy TopKAccuracy F1 Perplexity MAE MSE RMSE CustomMetric CompositeEvalMetric))) (defn accuracy "Basic Accuracy Metric" @@ -74,11 +74,21 @@ [f-eval mname] `(new CustomMetric (util/scala-fn ~f-eval) ~mname)) +(defn comp-metric + "Create a metric instance composed out of several metrics" + [metrics] + (let [cm (CompositeEvalMetric.)] + (doseq [m metrics] (.add cm m)) + cm)) + (defn get - "Get the values of the metric in a vector form (name and value)" + "Get the values of the metric in as a map of {name value} pairs" [metric] - (let [[[mname] [mvalue]] (util/tuple->vec (.get metric))] - [mname mvalue])) + (let [m (apply zipmap (-> (.get metric) + util/tuple->vec))] + (if-not (instance? CompositeEvalMetric metric) + (first m) + m))) (defn reset "clear the internal statistics to an initial state" diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj index d6da2ec9ee58..1f4dba35fa7a 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/eval_metric_test.clj @@ -57,3 +57,12 @@ "my-metric")] (eval-metric/update metric [(ndarray/ones [2])] [(ndarray/ones [2])]) (is (= ["my-metric" 0.0] (eval-metric/get metric))))) + +(deftest test-comp-metric + (let [metric (eval-metric/comp-metric [(eval-metric/accuracy) + (eval-metric/f1) + (eval-metric/top-k-accuracy 2)])] + (eval-metric/update metric [(ndarray/ones [2])] [(ndarray/ones [2 3])]) + (is (= {"accuracy" 0.0 + "f1" 0.0 + "top_k_accuracy" 1.0} (eval-metric/get metric)))))