From a74ebce12a7de82be1b39e63aaf53bb807679b1f Mon Sep 17 00:00:00 2001 From: Carin Meier Date: Wed, 12 Sep 2018 20:05:56 -0400 Subject: [PATCH] MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package (#12387) * Bring clojure package inline with new DataDesc and Layout in Scala package * formatting cljfmt * revert the implementation of module fit back now that DataDesc issue if fixed - update Module example to use provide-data-desc and provide-label-desc * update to provide-data-desc and provide-label-desc * decrease epochs to speed example * Add tests and docstrings * remove let --- .../examples/gan/src/gan/gan_mnist.clj | 6 +- .../examples/module/src/mnist_mlp.clj | 2 +- .../src/pre_trained_models/fine_tune.clj | 2 +- .../src/pre_trained_models/predict_image.clj | 2 +- .../examples/rnn/src/rnn/train_char_rnn.clj | 12 +-- .../src/org/apache/clojure_mxnet/io.clj | 88 +++++++++++++------ .../src/org/apache/clojure_mxnet/layout.clj | 35 ++++++++ .../src/org/apache/clojure_mxnet/module.clj | 58 ++---------- .../src/org/apache/clojure_mxnet/ndarray.clj | 2 +- .../src/org/apache/clojure_mxnet/symbol.clj | 2 +- .../test/org/apache/clojure_mxnet/io_test.clj | 53 ++++++++++- .../org/apache/clojure_mxnet/module_test.clj | 9 +- .../org/apache/clojure_mxnet/test_util.clj | 6 +- 13 files changed, 178 insertions(+), 99 deletions(-) create mode 100644 contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj diff --git a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj index 14dd2c5cc3f7..e2e3364535ec 100644 --- a/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj +++ b/contrib/clojure-package/examples/gan/src/gan/gan_mnist.clj @@ -159,13 +159,13 @@ (defn train [devs] (let [mod-d (-> (m/module (discriminator) {:contexts devs :data-names ["data"] :label-names ["label"]}) - (m/bind {:data-shapes (mx-io/provide-data mnist-iter) - :label-shapes (mx-io/provide-label mnist-iter) + (m/bind {:data-shapes (mx-io/provide-data-desc mnist-iter) + :label-shapes (mx-io/provide-label-desc mnist-iter) :inputs-need-grad true}) (m/init-params {:initializer (init/normal 0.02)}) (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})})) mod-g (-> (m/module (generator) {:contexts devs :data-names ["rand"] :label-names nil}) - (m/bind {:data-shapes (mx-io/provide-data rand-noise-iter)}) + (m/bind {:data-shapes (mx-io/provide-data-desc rand-noise-iter)}) (m/init-params {:initializer (init/normal 0.02)}) (m/init-optimizer {:optimizer (opt/adam {:learning-rate lr :wd 0.0 :beta1 beta1})}))] diff --git a/contrib/clojure-package/examples/module/src/mnist_mlp.clj b/contrib/clojure-package/examples/module/src/mnist_mlp.clj index 74edf71172c7..c5ffbbede852 100644 --- a/contrib/clojure-package/examples/module/src/mnist_mlp.clj +++ b/contrib/clojure-package/examples/module/src/mnist_mlp.clj @@ -85,7 +85,7 @@ (m/module (get-symbol) {:contexts devs})) metric (eval-metric/accuracy)] (-> mod - (m/bind {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)}) + (m/bind {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)}) (m/init-params) (m/init-optimizer {:optimizer (optimizer/sgd {:learning-rate 0.01 :momentum 0.9})})) diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj index f2b9eddeb2af..93c121f9fc16 100644 --- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj +++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/fine_tune.clj @@ -80,7 +80,7 @@ (defn fit [devs msymbol arg-params aux-params] (let [mod (-> (m/module msymbol {:contexts devs}) - (m/bind {:data-shapes (mx-io/provide-data train-iter) :label-shapes (mx-io/provide-label val-iter)}) + (m/bind {:data-shapes (mx-io/provide-data-desc train-iter) :label-shapes (mx-io/provide-label-desc val-iter)}) (m/init-params {:arg-params arg-params :aux-params aux-params :allow-missing true}))] (m/fit mod diff --git a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj index 12bdb12fb5ac..71202bc000f9 100644 --- a/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj +++ b/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj @@ -92,7 +92,7 @@ (comment - (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg") + (predict "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/python/predict_image/cat.jpg" true) ;; ({:prob 0.69066674, :label "n02122948 kitten, kitty"} ;; {:prob 0.04466057, :label "n01323155 kit"} ;; {:prob 0.029682875, :label "n01318894 pet"} diff --git a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj index 29aba26b1951..150cd94e673c 100644 --- a/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj +++ b/contrib/clojure-package/examples/rnn/src/rnn/train_char_rnn.clj @@ -109,16 +109,16 @@ :label-name "softmax_label" :data-batch-size batch-size :last-batch-handle "pad"}) - data-and-labels (merge (data-desc->map (mx-io/provide-data train-iter)) - (data-desc->map (mx-io/provide-label train-iter)) + data-and-labels (merge (data-desc->map (mx-io/provide-data-desc train-iter)) + (data-desc->map (mx-io/provide-label-desc train-iter)) init-states) init-states-data (mapv (fn [[k v]] (ndarray/zeros v {:ctx ctx})) init-states) rnn-sym (sym-gen (first buckets)) rnn-mod (-> (m/module rnn-sym {:contexts devs}) - (m/bind {:data-shapes (into (mx-io/provide-data train-iter) + (m/bind {:data-shapes (into (mx-io/provide-data-desc train-iter) (mapv (fn [[k v]] {:name k :shape v}) init-states)) - :label-shapes (mx-io/provide-label train-iter)}) + :label-shapes (mx-io/provide-label-desc train-iter)}) (m/init-params {:initializer (init/xavier {:factor-type "in" :magnitude 2.34})}) (m/init-optimizer {:optimizer (optimizer/adam {:learning-rate learning-rate :wd 0.0001})})) metric (eval-metric/custom-metric @@ -141,8 +141,8 @@ "perplexity")] - ;; Train for 2 epochs and then show the results of 75 - (doseq [epoch-num (range 2)] + ;; Train for 1 epochs and then show the results of 75 + (doseq [epoch-num (range 1)] (println "Doing epoch " epoch-num) (mx-io/reduce-batches train-iter diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj index d6f1499ba829..a2b639934f49 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/io.clj @@ -17,11 +17,12 @@ (ns org.apache.clojure-mxnet.io (:refer-clojure :exclude [next]) - (:require [org.apache.clojure-mxnet.base :as base] + (:require [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.base :as base] [org.apache.clojure-mxnet.shape :as mx-shape] [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.dtype :as dtype] - [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.random :as random]) (:import (org.apache.mxnet IO DataDesc DataBatch NDArray) @@ -57,18 +58,48 @@ (defn resize-iter [iter nbatch]) -(defn provide-data [pack-iterator] +(defn provide-data + "Provides the description of the data iterator in the form of + [{:name name :shape shape-vec}]" + [pack-iterator] (->> pack-iterator (.provideData) (util/scala-map->map) (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) -(defn provide-label [pack-iterator] +(defn provide-label + "Provides the description of the label iterator in the form of + [{:name name :shape shape-vec}]" + [pack-iterator] (->> pack-iterator (.provideLabel) (util/scala-map->map) (mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) +(defn data-desc->map [dd] + {:name (.name dd) + :shape (mx-shape/->vec (.shape dd)) + :dtype (.dtype dd) + :layout (.layout dd)}) + +(defn provide-data-desc + "Provides the Data Desc of the data iterator in the form of + [{:name name :shape shape-vec :dtype dtype :layout layout}]" + [pack-iterator] + (->> pack-iterator + (.provideDataDesc) + (util/scala-vector->vec) + (mapv data-desc->map))) + +(defn provide-label-desc + "Provides the Data Desc of the label iterator in the form of + [{:name name :shape shape-vec :dtype dtype :layout layout}]" + [pack-iterator] + (->> pack-iterator + (.provideLabelDesc) + (util/scala-vector->vec) + (mapv data-desc->map))) + (defn reset [iterator] (.reset iterator)) @@ -194,7 +225,8 @@ (defn ndarray-iter " * NDArrayIter object in mxnet. Taking NDArray to get dataiter. * - * @param data vector of iter + * @param data vector of iter - Can either by in the form for [ndarray..] or + * {data-desc0 ndarray0 data-desc2 ndarray2 ...} * @opts map of: * :label Same as data, but is not fed to the model during testing. * :data-batch-size Batch Size (default 1) @@ -213,14 +245,23 @@ last-batch-handle "pad" data-name "data" label-name "label"}}] - (new NDArrayIter - (util/vec->indexed-seq data) - (if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) - (int data-batch-size) - shuffle - last-batch-handle - data-name - label-name)) + (if (map? data) + (new NDArrayIter + (.toIndexedSeq (util/list-map data)) + (if label + (.toIndexedSeq (util/list-map label)) + (util/empty-indexed-seq)) + (int data-batch-size) + shuffle + last-batch-handle) + (new NDArrayIter + (util/vec->indexed-seq data) + (if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) + (int data-batch-size) + shuffle + last-batch-handle + data-name + label-name))) ([data] (ndarray-iter data {}))) @@ -230,24 +271,19 @@ (s/def ::name string?) (s/def ::shape vector?) (s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64}) +(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED + layout/NCHW + layout/NTC + layout/NT + layout/N})) (s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout])) -;; NCHW is N:batch size C: channel H: height W: width -;;; other layouts are -;; NT, TNC, nad N -;; the shape length must match the lengh of the layout string size (defn data-desc ([{:keys [name shape dtype layout] :as opts - :or {dtype base/MX_REAL_TYPE}}] + :or {dtype base/MX_REAL_TYPE + layout layout/UNDEFINED}}] (util/validate! ::data-desc opts "Invalid data description") - (let [sc (count shape) - layout (or layout (cond - (= 1 sc) "N" - (= 2 sc) "NT" - (= 3 sc) "TNC" - (= 4 sc) "NCHW" - :else (apply str (repeat sc "?"))))] - (new DataDesc name (mx-shape/->shape shape) dtype layout))) + (new DataDesc name (mx-shape/->shape shape) dtype layout)) ([name shape] (data-desc {:name name :shape shape}))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj new file mode 100644 index 000000000000..f379a7a02d28 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/layout.clj @@ -0,0 +1,35 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.layout + (:import (org.apache.mxnet Layout))) + +;; +;; Layout definition of DataDesc +;; N Batch size +;; C channels +;; H Height +;; W Weight +;; T sequence length +;; __undefined__ default value of Layout +;; + +(def UNDEFINED (Layout/UNDEFINED)) ;"__UNDEFINED__" +(def NCHW (Layout/NCHW)) ;=> "NCHW" +(def NTC (Layout/NTC)) ;=> "NTC" +(def NT (Layout/NT)) ;=> "NT" +(def N (Layout/N)) ;=> "N diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj index ab6d345fe91d..aa5ce39f7a80 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/module.clj @@ -88,8 +88,8 @@ can perform computation with the module. mod : module map of opts: - :data-shapes Typically is (provide-data data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout - :label-shapes Typically is (provide-label data-iter). map of :name :shape :dtype and :layout + :data-shapes Typically is (provide-data-desc data-iter). Data shape must be in the form of io/data-desc with is a map of :name :shape :dtype and :layout + :label-shapes Typically is (provide-label-desc data-iter). map of :name :shape :dtype and :layout :for-training Default is `true`. Whether the executors should be bind for training. :inputs-need-grad Default is `false`. Whether the gradients to the input data need to be computed. @@ -547,54 +547,12 @@ `:or {num-epoch 1 fit-params (new FitParams)}}] (util/validate! ::fit-options opts "Invalid options for fit") - (let [fmod (-> mod - (bind {:data-shapes (mx-io/provide-data train-data) - :label-shapes (mx-io/provide-label train-data) - :for-training true - :force-rebind (.forceRebind fit-params)}) - (init-params (remove (fn [[k v]] (nil? v)) - {:initializer (.initializer fit-params) - :arg-params (.argParams fit-params) - :aux-params (.auxParams fit-params) - :allow-missing (.allowMissing fit-params)})) - (init-optimizer (remove (fn [[k v]] (nil? v)) - {:optimizer (.optimizer fit-params) - :kvstore (.kvstore fit-params)}))) - eval-metric (or (.evalMetric fit-params) (eval-metric/accuracy)) - val-metric (or (util/option->value (.validationMetric fit-params)) (eval-metric/accuracy))] - (doseq [i (range num-epoch)] - (let [tic (System/currentTimeMillis)] - (mx-io/reduce-batches train-data - (fn [batch-num batch] - (-> fmod - (forward batch) - (backward) - (update) - (update-metric eval-metric (mx-io/batch-label batch))) - (when-let [cb (util/option->value (.batchEndCallback fit-params))] - (callback/invoke cb i batch-num eval-metric)) - (.dispose batch) - (inc batch-num))) - (println "Epoch " i " Train-" (eval-metric/get eval-metric)) - (println "Epoch " i " Time cost-" (- (System/currentTimeMillis) tic)) - - ;;sync across kvstores - (get-params fmod) - (when-let [cb (util/option->value (.epochEndCallback fit-params))] - (callback/invoke cb i 0 val-metric)) - - ;; evaluation on the validation set - (when eval-data - (let [res (score fmod {:eval-data eval-data :eval-metric eval-metric :epoch i})] - (println "Epoch " i " Validation- " res))))) - fmod) - ;; old way if the problem with the sizes get resolved in DataDesc - #_(doto mod - (.fit - train-data - (util/->option eval-data) - (int num-epoch) - fit-params))) + (doto mod + (.fit + train-data + (util/->option eval-data) + (int num-epoch) + fit-params))) (s/def ::eval-data ::train-data) (s/def ::num-batch integer?) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj index 7ca4ede9733c..e37a8bc8c98d 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj @@ -89,7 +89,7 @@ (NDArray/arange (float start) ($/option (float stop)) step repeat ctx dtype)) ([start stop] (arange start stop {}))) - + (defn slice "Return a sliced NDArray that shares memory with current one." ([ndarray i] diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj index 12135fb75cab..58b1d6d49fff 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol.clj @@ -144,7 +144,7 @@ which must be known from the rest of the net." ([start {:keys [step repeat dtype] :or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} - :as opts}] + :as opts}] (Symbol/arange (float start) ($/option nil) step repeat true nil dtype)) ([start] (arange-with-inference start {}))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj index ace39ec201eb..9babf1e22536 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/io_test.clj @@ -22,7 +22,9 @@ [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.util :as util] [org.apache.clojure-mxnet.shape :as mx-shape] - [clojure.test :refer :all])) + [clojure.test :refer :all] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.layout :as layout])) (deftest test-mnsit-iter-and-mnist-pack (let [_ (when-not (.exists (io/file "data/train-images-idx3-ubyte")) @@ -59,6 +61,31 @@ (is (= label1 label0)) (is (= data1 data0)))))) +(deftest test-provide-data-and-label + (let [test-data (mx-io/mnist-iter {:image "data/train-images-idx3-ubyte" + :label "data/train-labels-idx1-ubyte" + :label-name "softmax_label" + :data-shape [1 28 28] + :label-shape [1 1 10] + :batch-size 100 + :shuffle true + :flat false + :silent false + :seed 10})] + (is (= [{:name "data", :shape [100 1 28 28]}] + (mx-io/provide-data test-data))) + (is (= [{:name "softmax_label", :shape [100]}] + (mx-io/provide-label test-data))) + (is (= [{:name "data", :shape [100 1 28 28] + :dtype dtype/FLOAT32 + :layout layout/UNDEFINED}] + (mx-io/provide-data-desc test-data))) + (is (= [{:name "softmax_label" + :shape [100] + :dtype dtype/FLOAT32 + :layout layout/UNDEFINED}] + (mx-io/provide-label-desc test-data))))) + (deftest test-image-record-iter (let [_ (when-not (.exists (io/file "data/cifar/train.rec")) (sh "scripts/get_cifar_data.sh")) @@ -162,4 +189,26 @@ :last-batch-handle "discard"}) nbatch2 7] (is (= nbatch2 (mx-io/reduce-batches data-iter2 (fn [result batch] (inc result))))) - (is (= [] (mx-io/iter-init-label data-iter2)))))) + (is (= [] (mx-io/iter-init-label data-iter2)))) + + ;;; testing with a specified layout + (let [label-desc (mx-io/data-desc {:name "label" + :shape [2 2] + :dtype dtype/INT32 + :layout layout/NT}) + data-desc (mx-io/data-desc {:name "data" + :shape [2 2 2] + :dtype dtype/FLOAT32 + :layout layout/NTC}) + label (ndarray/ones [2 2] {:dtype dtype/INT32}) + data (ndarray/ones [2 2 2] {:dtype dtype/FLOAT32}) + data-iter3 (mx-io/ndarray-iter {data-desc data} + {:label {label-desc label}})] + (is (= {:dtype dtype/FLOAT32 :layout layout/NTC} + (-> (mx-io/provide-data-desc data-iter3) + first + (select-keys [:dtype :layout])))) + (is (= {:dtype dtype/INT32 :layout layout/NT} + (-> (mx-io/provide-label-desc data-iter3) + first + (select-keys [:dtype :layout]))))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj index 0f71b5a850cc..d53af2ec249d 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/module_test.clj @@ -20,6 +20,7 @@ [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.module :as m] [org.apache.clojure-mxnet.monitor :as monitor] [org.apache.clojure-mxnet.ndarray :as ndarray] @@ -54,9 +55,9 @@ c (sym/+ a (sym/+ (sym/* b 2) (sym/* c 3))) mod (m/module c ["b" "c" "a"] nil [(context/cpu 0) (context/cpu 1)])] (-> mod - (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout "NT"} - {:name "c" :shape [5 5] :layout "NT"} - {:name "a" :shape [5 5] :layout "NT"}] + (m/bind {:data-shapes [{:name "b" :shape [5 5] :layout layout/NT} + {:name "c" :shape [5 5] :layout layout/NT} + {:name "a" :shape [5 5] :layout layout/NT}] :inputs-need-grad true}) (m/init-params) (m/forward {:data [(ndarray/ones [5 5]) @@ -172,7 +173,7 @@ (sym/linear-regression-output "softmax" {:data v :grad-scale 2})) mod (m/module x)] - (m/bind mod {:data-shapes (mx-io/provide-data train-data) :label-shapes (mx-io/provide-label train-data)}) + (m/bind mod {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label train-data)}) (let [arg-params-correct {"fc_0_weight" (ndarray/array [0.15 0.2 0.25 0.3] [2 2]) "fc_0_bias" (ndarray/array [0.35 0.35] [2]) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj index ecd54ca72773..d632c969eae9 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/test_util.clj @@ -23,7 +23,7 @@ (let [diff (Math/abs (- x y))] (< diff tolerance)) (and - (= (count x) (count y)) - (reduce (fn [x y] (and x y)) - (map #(approx= tolerance %1 %2) x y))))) + (= (count x) (count y)) + (reduce (fn [x y] (and x y)) + (map #(approx= tolerance %1 %2) x y)))))