Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixed param coercion of clojure executor/forward (#12627) (#12630)
Browse files Browse the repository at this point in the history
  • Loading branch information
paroda authored and gigasquid committed Sep 21, 2018
1 parent 504d24c commit 846bda4
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
(do (.forward executor)
executor))
([executor is-train kwargs]
(do (.forward executor is-train (util/nil-or-coerce-param kwargs #{"scala.collection.immutable.Map"})))))
(do (.forward executor is-train (util/map->scala-tuple-seq kwargs))
executor)))

(defn backward
"* Do backward pass to get the gradient of arguments.
Expand Down
21 changes: 21 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,24 @@
(throw (ex-info error-msg
(s/explain-data spec value)))))

(defn map->scala-tuple-seq
"* Convert a map to a scala-Seq of scala-Tubple.
* Should also work if a seq of seq of 2 things passed.
* Otherwise passed through unchanged."
[map-or-tuple-seq]
(letfn [(key->name [k]
(if (or (keyword? k) (string? k) (symbol? k))
(string/replace (name k) "-" "_")
k))
(->tuple [kvp-or-tuple]
(if (coll? kvp-or-tuple)
(let [[k v] kvp-or-tuple]
($/tuple (key->name k) v))
;; pass-through
kvp-or-tuple))]
(if (coll? map-or-tuple-seq)
(->> map-or-tuple-seq
(map ->tuple)
(apply $/immutable-list))
;; pass-through
map-or-tuple-seq)))
Original file line number Diff line number Diff line change
Expand Up @@ -74,3 +74,23 @@
(is (every? #(= 4.0 %) (->> (executor/outputs exec)
(map ndarray/->vec)
first)))))

(deftest test-forward
(let [a (sym/variable "a")
b (sym/variable "b")
c (sym/+ a b)
ex (sym/bind c {:a (ndarray/* (ndarray/ones [1 2]) 2)
:b (ndarray/* (ndarray/ones [1 2]) 3)})]
;; test forward with binded values
(executor/forward ex)
(is (= [5.0 5.0] (-> ex executor/outputs first ndarray/->vec)))
;; test forward with new a (b is still [3.0 3.0]
(executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 4)})
(is (= [7.0 7.0] (-> ex executor/outputs first ndarray/->vec)))
;; test forward with new b (a is still [4.0 4.0]
(executor/forward ex false {:b (ndarray/* (ndarray/ones [1 2]) 5)})
(is (= [9.0 9.0] (-> ex executor/outputs first ndarray/->vec)))
;; test forward with new a & b
(executor/forward ex false {:a (ndarray/* (ndarray/ones [1 2]) 6)
:b (ndarray/* (ndarray/ones [1 2]) 7)})
(is (= [13.0 13.0] (-> ex executor/outputs first ndarray/->vec)))))
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,19 @@
data2 [1 1 1 1 9 9 9 9]
data3 [1 1 1 2]]
(is (not (test-util/approx= 1e-9 data1 data2)))
(is (test-util/approx= 2 data1 data3))))
(is (test-util/approx= 2 data1 data3))))

(deftest test-map->scala-tuple-seq
;; convert as much, and pass-through the rest
(is (nil? (util/map->scala-tuple-seq nil)))
(is (= "List()"
(str (util/map->scala-tuple-seq {}))
(str (util/map->scala-tuple-seq []))
(str (util/map->scala-tuple-seq '()))))
(is (= "List(a, b)" (str (util/map->scala-tuple-seq ["a" "b"]))))
(is (= "List((a,b), (c,d), (e,f), (a_b,g), (c_d,h), (e_f,i))"
(str (util/map->scala-tuple-seq {:a "b", 'c "d", "e" "f"
:a-b "g", 'c-d "h", "e-f" "i"}))))
(let [nda (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})]
(is (= "a_b" (._1 (.head nda))))
(is (= [1.0 1.0] (ndarray/->vec (._2 (.head nda)))))))

0 comments on commit 846bda4

Please sign in to comment.