Skip to content

Commit

Permalink
MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout…
Browse files Browse the repository at this point in the history
… in Scala Package (apache#12387)

* Bring clojure package inline with new DataDesc and Layout in Scala package

* formatting cljfmt

* revert the implementation of module fit back now that DataDesc issue if fixed
- update Module example to use provide-data-desc and provide-label-desc

* update to provide-data-desc and provide-label-desc

* decrease epochs to speed example

* Add tests and docstrings

* remove let
  • Loading branch information
gigasquid authored and anirudh2290 committed Sep 19, 2018
1 parent c960493 commit a74ebce
Show file tree
Hide file tree
Showing 13 changed files with 178 additions and 99 deletions.
6 changes: 3 additions & 3 deletions contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@

(defn train [devs]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
(m/bind {:data-shapes (mx-io/provide-data mnist-iter)
:label-shapes (mx-io/provide-label mnist-iter)
(m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
:label-shapes (mx-io/provide-label-desc mnist-iter)
:inputs-need-grad true})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
(m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
(m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]

Expand Down
2 changes: 1 addition & 1 deletion contrib/clojure-package/examples/module/src/mnist_mlp.clj
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
(m/module (get-symbol) {:contexts devs}))
metric (eval-metric/accuracy)]
(-> mod
(m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
(m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)})
(m/init-params)
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

(defn fit [devs msymbol arg-params aux-params]
(let [mod (-> (m/module msymbol {:contexts devs})
(m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)})
(m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)})
(m/init-params {:arg-params arg-params :aux-params aux-params
:allow-missing true}))]
(m/fit mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

(comment

(predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg")
(predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true)
;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"}
;; {:prob 0.04466057, :label "n01323155 kit"}
;; {:prob 0.029682875, :label "n01318894 pet"}
Expand Down
12 changes: 6 additions & 6 deletions contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,16 @@
:label-name "softmax_label"
:data-batch-size batch-size
:last-batch-handle "pad"})
data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter))
(data-desc->map (mx-io/provide-label train-iter))
data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter))
(data-desc->map (mx-io/provide-label-desc train-iter))
init-states)
init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states)
rnn-sym (sym-gen (first buckets))

rnn-mod (-> (m/module rnn-sym {:contexts devs})
(m/bind {:data-shapes (into (mx-io/provide-data train-iter)
(m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter)
(mapv (fn [[k v]] {:name k :shape v}) init-states))
:label-shapes (mx-io/provide-label train-iter)})
:label-shapes (mx-io/provide-label-desc train-iter)})
(m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})})
(m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})}))
metric (eval-metric/custom-metric
Expand All @@ -141,8 +141,8 @@

"perplexity")]

;; Train for 2 epochs and then show the results of 75
(doseq [epoch-num (range 2)]
;; Train for 1 epochs and then show the results of 75
(doseq [epoch-num (range 1)]
(println "Doing epoch " epoch-num)
(mx-io/reduce-batches
train-iter
Expand Down
88 changes: 62 additions & 26 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

(ns org.apache.clojure-mxnet.io
(:refer-clojure :exclude [next])
(:require [org.apache.clojure-mxnet.base :as base]
(:require [clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.dtype :as dtype]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.random :as random])
(:import (org.apache.mxnet IO DataDesc DataBatch NDArray)
Expand Down Expand Up @@ -57,18 +58,48 @@

(defn resize-iter [iter nbatch])

(defn provide-data [pack-iterator]
(defn provide-data
"Provides the description of the data iterator in the form of
[{:name name :shape shape-vec}]"
[pack-iterator]
(->> pack-iterator
(.provideData)
(util/scala-map->map)
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))

(defn provide-label [pack-iterator]
(defn provide-label
"Provides the description of the label iterator in the form of
[{:name name :shape shape-vec}]"
[pack-iterator]
(->> pack-iterator
(.provideLabel)
(util/scala-map->map)
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)}))))

(defn data-desc->map [dd]
{:name (.name dd)
:shape (mx-shape/->vec (.shape dd))
:dtype (.dtype dd)
:layout (.layout dd)})

(defn provide-data-desc
"Provides the Data Desc of the data iterator in the form of
[{:name name :shape shape-vec :dtype dtype :layout layout}]"
[pack-iterator]
(->> pack-iterator
(.provideDataDesc)
(util/scala-vector->vec)
(mapv data-desc->map)))

(defn provide-label-desc
"Provides the Data Desc of the label iterator in the form of
[{:name name :shape shape-vec :dtype dtype :layout layout}]"
[pack-iterator]
(->> pack-iterator
(.provideLabelDesc)
(util/scala-vector->vec)
(mapv data-desc->map)))

(defn reset [iterator]
(.reset iterator))

Expand Down Expand Up @@ -194,7 +225,8 @@
(defn ndarray-iter
" * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
* @param data vector of iter
* @param data vector of iter - Can either by in the form for [ndarray..] or
* {data-desc0 ndarray0 data-desc2 ndarray2 ...}
* @opts map of:
* :label Same as data, but is not fed to the model during testing.
* :data-batch-size Batch Size (default 1)
Expand All @@ -213,14 +245,23 @@
last-batch-handle "pad"
data-name "data"
label-name "label"}}]
(new NDArrayIter
(util/vec->indexed-seq data)
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle
data-name
label-name))
(if (map? data)
(new NDArrayIter
(.toIndexedSeq (util/list-map data))
(if label
(.toIndexedSeq (util/list-map label))
(util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle)
(new NDArrayIter
(util/vec->indexed-seq data)
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle
data-name
label-name)))
([data]
(ndarray-iter data {})))

Expand All @@ -230,24 +271,19 @@
(s/def ::name string?)
(s/def ::shape vector?)
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED
layout/NCHW
layout/NTC
layout/NT
layout/N}))
(s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout]))

;; NCHW is N:batch size C: channel H: height W: width
;;; other layouts are
;; NT, TNC, nad N
;; the shape length must match the lengh of the layout string size
(defn data-desc
([{:keys [name shape dtype layout] :as opts
:or {dtype base/MX_REAL_TYPE}}]
:or {dtype base/MX_REAL_TYPE
layout layout/UNDEFINED}}]
(util/validate! ::data-desc opts "Invalid data description")
(let [sc (count shape)
layout (or layout (cond
(= 1 sc) "N"
(= 2 sc) "NT"
(= 3 sc) "TNC"
(= 4 sc) "NCHW"
:else (apply str (repeat sc "?"))))]
(new DataDesc name (mx-shape/->shape shape) dtype layout)))
(new DataDesc name (mx-shape/->shape shape) dtype layout))
([name shape]
(data-desc {:name name :shape shape})))

Expand Down
35 changes: 35 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.layout
(:import (org.apache.mxnet Layout)))

;;
;; Layout definition of DataDesc
;; N Batch size
;; C channels
;; H Height
;; W Weight
;; T sequence length
;; __undefined__ default value of Layout
;;

(def UNDEFINED (Layout/UNDEFINED)) ;"__UNDEFINED__"
(def NCHW (Layout/NCHW)) ;=> "NCHW"
(def NTC (Layout/NTC)) ;=> "NTC"
(def NT (Layout/NT)) ;=> "NT"
(def N (Layout/N)) ;=> "N
58 changes: 8 additions & 50 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@
can perform computation with the module.
mod : module
map of opts:
:data-shapes Typically is (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
:label-shapes Typically is (provide-label data-iter). map of :name :shape :dtype and :layout
:data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
:label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout
:for-training Default is `true`. Whether the executors should be bind for training.
:inputs-need-grad Default is `false`.
Whether the gradients to the input data need to be computed.
Expand Down Expand Up @@ -547,54 +547,12 @@
`:or {num-epoch 1
fit-params (new FitParams)}}]
(util/validate! ::fit-options opts "Invalid options for fit")
(let [fmod (-> mod
(bind {:data-shapes (mx-io/provide-data train-data)
:label-shapes (mx-io/provide-label train-data)
:for-training true
:force-rebind (.forceRebind fit-params)})
(init-params (remove (fn [[k v]] (nil? v))
{:initializer (.initializer fit-params)
:arg-params (.argParams fit-params)
:aux-params (.auxParams fit-params)
:allow-missing (.allowMissing fit-params)}))
(init-optimizer (remove (fn [[k v]] (nil? v))
{:optimizer (.optimizer fit-params)
:kvstore (.kvstore fit-params)})))
eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy))
val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))]
(doseq [i (range num-epoch)]
(let [tic (System/currentTimeMillis)]
(mx-io/reduce-batches train-data
(fn [batch-num batch]
(-> fmod
(forward batch)
(backward)
(update)
(update-metric eval-metric (mx-io/batch-label batch)))
(when-let [cb (util/option->value (.batchEndCallback fit-params))]
(callback/invoke cb i batch-num eval-metric))
(.dispose batch)
(inc batch-num)))
(println "Epoch " i " Train-" (eval-metric/get eval-metric))
(println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic))

;;sync across kvstores
(get-params fmod)
(when-let [cb (util/option->value (.epochEndCallback fit-params))]
(callback/invoke cb i 0 val-metric))

;; evaluation on the validation set
(when eval-data
(let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})]
(println "Epoch " i " Validation- " res)))))
fmod)
;; old way if the problem with the sizes get resolved in DataDesc
#_(doto mod
(.fit
train-data
(util/->option eval-data)
(int num-epoch)
fit-params)))
(doto mod
(.fit
train-data
(util/->option eval-data)
(int num-epoch)
fit-params)))

(s/def ::eval-data ::train-data)
(s/def ::num-batch integer?)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
which must be known from the rest of the net."
([start {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
:as opts}]
(Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
([start]
(arange-with-inference start {})))
Expand Down
Loading

0 comments on commit a74ebce

Please sign in to comment.