From e969b9ab9aee5dde6ca0a97744de8998424b6c2e Mon Sep 17 00:00:00 2001 From: pkpa Date: Fri, 21 Sep 2018 15:18:28 +0530 Subject: [PATCH] Fixed param coercion of clojure executor/forward (#12627) --- .../src/org/apache/clojure_mxnet/executor.clj | 26 ++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj index 4f4155e2d80b..64857b3cb929 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/executor.clj @@ -18,6 +18,8 @@ (ns org.apache.clojure-mxnet.executor (:require [org.apache.clojure-mxnet.util :as util] [clojure.reflect :as r] + [clojure.string :as str] + [t6.from-scala.core :as $] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.shape :as mx-shape])) @@ -26,6 +28,28 @@ (defn ->vec [nd-array] (vec (.toArray nd-array))) +(defn- coerce-map->tuple-seq + "* Convert a map to a scala-Seq of scala-Tubple. + * Should also work if a seq of seq of 2 things passed. + * Otherwise passed through unchanged." + [map-or-tuple-seq] + (letfn [(key->name [k] + (if (or (keyword? k) (string? k) (symbol? k)) + (str/replace (name k) "-" "_") + k)) + (->tuple [kvp-or-tuple] + (if (coll? kvp-or-tuple) + (let [[k v] kvp-or-tuple] + ($/tuple (key->name k) v)) + ;; pass-through + kvp-or-tuple))] + (if (coll? map-or-tuple-seq) + (->> map-or-tuple-seq + (map ->tuple) + (apply $/immutable-list)) + ;; pass-through + map-or-tuple-seq))) + (defn forward "* Calculate the outputs specified by the binded symbol. * @param is-train whether this forward is for evaluation purpose. @@ -34,7 +58,7 @@ (do (.forward executor) executor)) ([executor is-train kwargs] - (do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"}))))) + (do (.forward executor is-train (coerce-map->tuple-seq kwargs))))) (defn backward "* Do backward pass to get the gradient of arguments.