Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixed param coercion of clojure executor/forward (#12627)
Browse files Browse the repository at this point in the history
  • Loading branch information
pkpa committed Sep 21, 2018
1 parent e82eef9 commit e969b9a
Showing 1 changed file with 25 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
(ns org.apache.clojure-mxnet.executor
(:require [org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[clojure.string :as str]
[t6.from-scala.core :as $]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.shape :as mx-shape]))

Expand All @@ -26,6 +28,28 @@
(defn ->vec [nd-array]
(vec (.toArray nd-array)))

(defn- coerce-map->tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
* Should also work if a seq of seq of 2 things passed.
* Otherwise passed through unchanged."
[map-or-tuple-seq]
(letfn [(key->name [k]
(if (or (keyword? k) (string? k) (symbol? k))
(str/replace (name k) "-" "_")
k))
(->tuple [kvp-or-tuple]
(if (coll? kvp-or-tuple)
(let [[k v] kvp-or-tuple]
($/tuple (key->name k) v))
;; pass-through
kvp-or-tuple))]
(if (coll? map-or-tuple-seq)
(->> map-or-tuple-seq
(map ->tuple)
(apply $/immutable-list))
;; pass-through
map-or-tuple-seq)))

(defn forward
"* Calculate the outputs specified by the binded symbol.
* @param is-train whether this forward is for evaluation purpose.
Expand All @@ -34,7 +58,7 @@
(do (.forward executor)
executor))
([executor is-train kwargs]
(do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
(do (.forward executor is-train (coerce-map->tuple-seq kwargs)))))

(defn backward
"* Do backward pass to get the gradient of arguments.
Expand Down

0 comments on commit e969b9a

Please sign in to comment.