Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package #12387

Merged
merged 8 commits into from
Sep 13, 2018
6 changes: 3 additions & 3 deletions contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@

(defn train [devs]
(let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]})
(m/bind {:data-shapes (mx-io/provide-data mnist-iter)
:label-shapes (mx-io/provide-label mnist-iter)
(m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter)
:label-shapes (mx-io/provide-label-desc mnist-iter)
:inputs-need-grad true})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))
mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil})
(m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)})
(m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)})
(m/init-params {:initializer (init/normal 0.02)})
(m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@
(m/module (get-symbol) {:contexts devs}))
metric (eval-metric/accuracy)]
(-> mod
(m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)})
(m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)})
(m/init-params)
(m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})}))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@

(defn fit [devs msymbol arg-params aux-params]
(let [mod (-> (m/module msymbol {:contexts devs})
(m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)})
(m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)})
(m/init-params {:arg-params arg-params :aux-params aux-params
:allow-missing true}))]
(m/fit mod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@

(comment

(predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg")
(predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this true mean in here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: not a big deal since it's a helper fn in an example, but it'd make sense to use an options map or kw arg ({:display true} or :display true). [outside this PR though, also]

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes agree

;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"}
;; {:prob 0.04466057, :label "n01323155 kit"}
;; {:prob 0.029682875, :label "n01318894 pet"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,16 +109,16 @@
:label-name "softmax_label"
:data-batch-size batch-size
:last-batch-handle "pad"})
data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter))
(data-desc->map (mx-io/provide-label train-iter))
data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter))
(data-desc->map (mx-io/provide-label-desc train-iter))
init-states)
init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states)
rnn-sym (sym-gen (first buckets))

rnn-mod (-> (m/module rnn-sym {:contexts devs})
(m/bind {:data-shapes (into (mx-io/provide-data train-iter)
(m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter)
(mapv (fn [[k v]] {:name k :shape v}) init-states))
:label-shapes (mx-io/provide-label train-iter)})
:label-shapes (mx-io/provide-label-desc train-iter)})
(m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})})
(m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})}))
metric (eval-metric/custom-metric
Expand All @@ -141,8 +141,8 @@

"perplexity")]

;; Train for 2 epochs and then show the results of 75
(doseq [epoch-num (range 2)]
;; Train for 1 epoch and then show the results of 75
(doseq [epoch-num (range 1)]
(println "Doing epoch " epoch-num)
(mx-io/reduce-batches
train-iter
Expand Down
88 changes: 62 additions & 26 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@

(ns org.apache.clojure-mxnet.io
(:refer-clojure :exclude [next])
(:require [org.apache.clojure-mxnet.base :as base]
(:require [clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.dtype :as dtype]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.layout :as layout]
[org.apache.clojure-mxnet.ndarray :as ndarray]
[org.apache.clojure-mxnet.random :as random])
(:import (org.apache.mxnet IO DataDesc DataBatch NDArray)
Expand Down Expand Up @@ -57,18 +58,48 @@

(defn resize-iter [iter nbatch])

(defn provide-data [pack-iterator]
(defn provide-data
  "Returns a description of this iterator's data in the form of
   [{:name name :shape shape-vec}]"
  [pack-iterator]
  (mapv (fn [[data-name data-shape]]
          {:name data-name :shape (mx-shape/->vec data-shape)})
        (util/scala-map->map (.provideData pack-iterator))))

(defn provide-label [pack-iterator]
(defn provide-label
  "Returns a description of this iterator's labels in the form of
   [{:name name :shape shape-vec}]"
  [pack-iterator]
  (mapv (fn [[label-name label-shape]]
          {:name label-name :shape (mx-shape/->vec label-shape)})
        (util/scala-map->map (.provideLabel pack-iterator))))

(defn data-desc->map
  "Converts a Scala DataDesc object into a Clojure map of
   :name :shape :dtype and :layout"
  [dd]
  (let [desc-name (.name dd)
        desc-shape (mx-shape/->vec (.shape dd))]
    {:name desc-name
     :shape desc-shape
     :dtype (.dtype dd)
     :layout (.layout dd)}))

(defn provide-data-desc
  "Returns the DataDesc of this iterator's data in the form of
   [{:name name :shape shape-vec :dtype dtype :layout layout}]"
  [pack-iterator]
  (mapv data-desc->map
        (util/scala-vector->vec (.provideDataDesc pack-iterator))))

(defn provide-label-desc
  "Returns the DataDesc of this iterator's labels in the form of
   [{:name name :shape shape-vec :dtype dtype :layout layout}]"
  [pack-iterator]
  (mapv data-desc->map
        (util/scala-vector->vec (.provideLabelDesc pack-iterator))))

(defn reset
  "Resets the iterator back to the beginning of the data."
  [iterator]
  (.reset iterator))

Expand Down Expand Up @@ -194,7 +225,8 @@
(defn ndarray-iter
" * NDArrayIter object in mxnet. Taking NDArray to get dataiter.
*
* @param data vector of iter
* @param data vector of iter - Can either be in the form of [ndarray ...] or
* {data-desc0 ndarray0 data-desc1 ndarray1 ...}
* @opts map of:
* :label Same as data, but is not fed to the model during testing.
* :data-batch-size Batch Size (default 1)
Expand All @@ -213,14 +245,23 @@
last-batch-handle "pad"
data-name "data"
label-name "label"}}]
(new NDArrayIter
(util/vec->indexed-seq data)
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle
data-name
label-name))
(if (map? data)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it mean if data is map? what if it is not? I don't know clojure well, just want to make sure it is intended.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes - it checks to see if the data is a map, if so, it is in the form of having a DataDesc associated with it and will be dispatched to the correct Java function signature and with scala interop. If it is not a map, it will dispatch to the original Java function signature without DataDesc.

The argument checking for the correct data structures can be improved by using core.spec in Clojure. It adds gradual type checking. It is in use in the module api, but it hasn't been added in yet in this namespace. I added a line item in the TODO page for the Clojure package to capture it for later improvement work.

(new NDArrayIter
(.toIndexedSeq (util/list-map data))
(if label
(.toIndexedSeq (util/list-map label))
(util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle)
(new NDArrayIter
(util/vec->indexed-seq data)
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq))
(int data-batch-size)
shuffle
last-batch-handle
data-name
label-name)))
([data]
(ndarray-iter data {})))

Expand All @@ -230,24 +271,19 @@
(s/def ::name string?)
(s/def ::shape vector?)
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64})
(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED
layout/NCHW
layout/NTC
layout/NT
layout/N}))
(s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout]))

;; NCHW is N:batch size C: channel H: height W: width
;; other layouts are
;; NT, TNC, and N
;; the shape length must match the length of the layout string size
(defn data-desc
([{:keys [name shape dtype layout] :as opts
:or {dtype base/MX_REAL_TYPE}}]
:or {dtype base/MX_REAL_TYPE
layout layout/UNDEFINED}}]
(util/validate! ::data-desc opts "Invalid data description")
(let [sc (count shape)
layout (or layout (cond
(= 1 sc) "N"
(= 2 sc) "NT"
(= 3 sc) "TNC"
(= 4 sc) "NCHW"
:else (apply str (repeat sc "?"))))]
(new DataDesc name (mx-shape/->shape shape) dtype layout)))
(new DataDesc name (mx-shape/->shape shape) dtype layout))
([name shape]
(data-desc {:name name :shape shape})))

Expand Down
35 changes: 35 additions & 0 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
;;
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.layout
  (:import (org.apache.mxnet Layout)))

;;
;; Layout definitions for DataDesc
;; N Batch size
;; C channels
;; H Height
;; W Width
;; T sequence length
;; __undefined__ default value of Layout
;;

(def UNDEFINED (Layout/UNDEFINED)) ;=> "__UNDEFINED__"
(def NCHW (Layout/NCHW)) ;=> "NCHW"
(def NTC (Layout/NTC)) ;=> "NTC"
(def NT (Layout/NT)) ;=> "NT"
(def N (Layout/N)) ;=> "N"
58 changes: 8 additions & 50 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@
can perform computation with the module.
mod : module
map of opts:
:data-shapes Typically is (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout
:label-shapes Typically is (provide-label data-iter). map of :name :shape :dtype and :layout
:data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc which is a map of :name :shape :dtype and :layout
:label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout
:for-training Default is `true`. Whether the executors should be bind for training.
:inputs-need-grad Default is `false`.
Whether the gradients to the input data need to be computed.
Expand Down Expand Up @@ -547,54 +547,12 @@
`:or {num-epoch 1
fit-params (new FitParams)}}]
(util/validate! ::fit-options opts "Invalid options for fit")
(let [fmod (-> mod
(bind {:data-shapes (mx-io/provide-data train-data)
:label-shapes (mx-io/provide-label train-data)
:for-training true
:force-rebind (.forceRebind fit-params)})
(init-params (remove (fn [[k v]] (nil? v))
{:initializer (.initializer fit-params)
:arg-params (.argParams fit-params)
:aux-params (.auxParams fit-params)
:allow-missing (.allowMissing fit-params)}))
(init-optimizer (remove (fn [[k v]] (nil? v))
{:optimizer (.optimizer fit-params)
:kvstore (.kvstore fit-params)})))
eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy))
val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))]
(doseq [i (range num-epoch)]
(let [tic (System/currentTimeMillis)]
(mx-io/reduce-batches train-data
(fn [batch-num batch]
(-> fmod
(forward batch)
(backward)
(update)
(update-metric eval-metric (mx-io/batch-label batch)))
(when-let [cb (util/option->value (.batchEndCallback fit-params))]
(callback/invoke cb i batch-num eval-metric))
(.dispose batch)
(inc batch-num)))
(println "Epoch " i " Train-" (eval-metric/get eval-metric))
(println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic))

;;sync across kvstores
(get-params fmod)
(when-let [cb (util/option->value (.epochEndCallback fit-params))]
(callback/invoke cb i 0 val-metric))

;; evaluation on the validation set
(when eval-data
(let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})]
(println "Epoch " i " Validation- " res)))))
fmod)
;; old way if the problem with the sizes get resolved in DataDesc
#_(doto mod
(.fit
train-data
(util/->option eval-data)
(int num-epoch)
fit-params)))
(doto mod
(.fit
train-data
(util/->option eval-data)
(int num-epoch)
fit-params)))

(s/def ::eval-data ::train-data)
(s/def ::num-batch integer?)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
(NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype))
([start stop]
(arange start stop {})))

(defn slice
"Return a sliced NDArray that shares memory with current one."
([ndarray i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
which must be known from the rest of the net."
([start {:keys [step repeat dtype]
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE}
:as opts}]
:as opts}]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space issue...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this case the cljfmt tool is fixing my earlier mistake :)

(Symbol/arange (float start) ($/option nil) step repeat true nil dtype))
([start]
(arange-with-inference start {})))
Expand Down
Loading