diff --git a/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj index e61e9ebf6fbb..164b5f2620f2 100644 --- a/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj +++ b/contrib/clojure-package/examples/imclassification/src/imclassification/train_mnist.clj @@ -25,7 +25,8 @@ [org.apache.clojure-mxnet.kvstore :as kvstore] [org.apache.clojure-mxnet.kvstore-server :as kvstore-server] [org.apache.clojure-mxnet.optimizer :as optimizer] - [org.apache.clojure-mxnet.eval-metric :as eval-metric]) + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.resource-scope :as resource-scope]) (:gen-class)) (def data-dir "data/") ;; the data directory to store the mnist data @@ -51,28 +52,6 @@ (when-not (.exists (io/file (str data-dir "train-images-idx3-ubyte"))) (sh "../../scripts/get_mnist_data.sh")) -;;; Load the MNIST datasets -(defonce train-data (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") - :label (str data-dir "train-labels-idx1-ubyte") - :label-name "softmax_label" - :input-shape [784] - :batch-size batch-size - :shuffle true - :flat true - :silent false - :seed 10 - :num-parts num-workers - :part-index 0})) - -(defonce test-data (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") - :label (str data-dir "t10k-labels-idx1-ubyte") - :input-shape [784] - :batch-size batch-size - :flat true - :silent false - :num-parts num-workers - :part-index 0})) - (defn get-symbol [] (as-> (sym/variable "data") data (sym/fully-connected "fc1" {:data data :num-hidden 128}) @@ -82,7 +61,31 @@ (sym/fully-connected "fc3" {:data data :num-hidden 10}) (sym/softmax-output "softmax" {:data data}))) -(defn start + +(defn train-data [] + (mx-io/mnist-iter {:image (str data-dir "train-images-idx3-ubyte") + :label (str data-dir "train-labels-idx1-ubyte") + :label-name "softmax_label" + :input-shape [784] + :batch-size batch-size + :shuffle true + :flat true + :silent false + :seed 10 + :num-parts num-workers + :part-index 0})) + +(defn eval-data [] + (mx-io/mnist-iter {:image (str data-dir "t10k-images-idx3-ubyte") + :label (str data-dir "t10k-labels-idx1-ubyte") + :input-shape [784] + :batch-size batch-size + :flat true + :silent false + :num-parts num-workers + :part-index 0})) + +(defn start ([devs] (start devs num-epoch)) ([devs _num-epoch] (when scheduler-host @@ -96,18 +99,16 @@ (do (println "Starting Training of MNIST ....") (println "Running with context devices of" devs) - (let [_mod (m/module (get-symbol) {:contexts devs})] - (m/fit _mod {:train-data train-data - :eval-data test-data + (resource-scope/with-let [_mod (m/module (get-symbol) {:contexts devs})] + (-> _mod + (m/fit {:train-data (train-data) + :eval-data (eval-data) :num-epoch _num-epoch :fit-params (m/fit-params {:kvstore kvstore :optimizer optimizer :eval-metric eval-metric})}) - (println "Finish fit") - _mod - ) - - )))) + (m/save-checkpoint {:prefix "target/test" :epoch _num-epoch})) + (println "Finish fit")))))) (defn -main [& args] (let [[dev dev-num] args diff --git a/contrib/clojure-package/examples/imclassification/test/imclassification/train_mnist_test.clj b/contrib/clojure-package/examples/imclassification/test/imclassification/train_mnist_test.clj index 2ebefc2fc664..f185891ab31e 100644 --- a/contrib/clojure-package/examples/imclassification/test/imclassification/train_mnist_test.clj +++ b/contrib/clojure-package/examples/imclassification/test/imclassification/train_mnist_test.clj @@ -16,7 +16,7 @@ ;; (ns imclassification.train-mnist-test - (:require + (:require [clojure.test :refer :all] [clojure.java.io :as io] [clojure.string :as s] @@ -26,14 +26,15 @@ (defn- file-to-filtered-seq [file] (->> - file + file (io/file) (io/reader) (line-seq) (filter #(not (s/includes? % "mxnet_version"))))) (deftest mnist-two-epochs-test - (module/save-checkpoint (mnist/start [(context/cpu)] 2) {:prefix "target/test" :epoch 2}) - (is (= - (file-to-filtered-seq "test/test-symbol.json.ref") - (file-to-filtered-seq "target/test-symbol.json")))) \ No newline at end of file + (do + (mnist/start [(context/cpu)] 2) + (is (= + (file-to-filtered-seq "test/test-symbol.json.ref") + (file-to-filtered-seq "target/test-symbol.json"))))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/resource_scope.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/resource_scope.clj new file mode 100644 index 000000000000..26673485e54c --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/resource_scope.clj @@ -0,0 +1,53 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.resource-scope + (:require [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet ResourceScope))) + +(defmacro + using + "Uses a Resource Scope for all forms. This is a way to manage all Native Resources like NDArray and Symbol - it will deallocate all Native Resources by calling close on them automatically. It will not call close on Native Resources returned from the form. + Example: + (resource-scope/using + (let [temp-x (ndarray/ones [3 1]) + temp-y (ndarray/ones [3 1])] + (ndarray/+ temp-x temp-y))) " + [& forms] + `(ResourceScope/using (new ResourceScope) (util/forms->scala-fn ~@forms))) + + +(defmacro + with-do + "Alias for a do within a resource scope using. + Example: + (resource-scope/with-do + (ndarray/ones [3 1]) + :all-cleaned-up) + " + [& forms] + `(using (do ~@forms))) + +(defmacro + with-let + "Alias for a let within a resource scope using. + Example: + (resource-scope/with-let [temp-x (ndarray/ones [3 1]) + temp-y (ndarray/ones [3 1])] + (ndarray/+ temp-x temp-y))" + [& forms] + `(using (let ~@forms))) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 43970c0abd79..6b5f50792ead 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -239,3 +239,9 @@ (apply $/immutable-list)) ;; pass-through map-or-tuple-seq))) + +(defmacro forms->scala-fn + "Creates a scala fn of zero args from forms" + [& forms] + `($/fn [] + (do ~@forms))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/resource_scope_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/resource_scope_test.clj new file mode 100644 index 000000000000..77df03402629 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/resource_scope_test.clj @@ -0,0 +1,146 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.resource-scope-test + (:require [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.resource-scope :as resource-scope] + [clojure.test :refer :all])) + + +(deftest test-resource-scope-with-ndarray + (let [native-resources (atom {}) + x (ndarray/ones [2 2]) + return-val (resource-scope/using + (let [temp-x (ndarray/ones [3 1]) + temp-y (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (swap! native-resources assoc :temp-y temp-y) + (ndarray/+ temp-x 1)))] + (is (true? (ndarray/is-disposed (:temp-x @native-resources)))) + (is (true? (ndarray/is-disposed (:temp-y @native-resources)))) + (is (false? (ndarray/is-disposed return-val))) + (is (false? (ndarray/is-disposed x))) + (is (= [2.0 2.0 2.0] (ndarray/->vec return-val))))) + +(deftest test-nested-resource-scope-with-ndarray + (let [native-resources (atom {}) + x (ndarray/ones [2 2]) + return-val (resource-scope/using + (let [temp-x (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (resource-scope/using + (let [temp-y (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-y temp-y)))))] + (is (true? (ndarray/is-disposed (:temp-y @native-resources)))) + (is (true? (ndarray/is-disposed (:temp-x @native-resources)))) + (is (false? (ndarray/is-disposed x))))) + +(deftest test-resource-scope-with-sym + (let [native-resources (atom {}) + x (sym/ones [2 2]) + return-val (resource-scope/using + (let [temp-x (sym/ones [3 1]) + temp-y (sym/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (swap! native-resources assoc :temp-y temp-y) + (sym/+ temp-x 1)))] + (is (true? (sym/is-disposed (:temp-x @native-resources)))) + (is (true? (sym/is-disposed (:temp-y @native-resources)))) + (is (false? (sym/is-disposed return-val))) + (is (false? (sym/is-disposed x))))) + +(deftest test-nested-resource-scope-with-ndarray + (let [native-resources (atom {}) + x (ndarray/ones [2 2]) + return-val (resource-scope/using + (let [temp-x (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (resource-scope/using + (let [temp-y (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-y temp-y)))))] + (is (true? (ndarray/is-disposed (:temp-y @native-resources)))) + (is (true? (ndarray/is-disposed (:temp-x @native-resources)))) + (is (false? (ndarray/is-disposed x))))) + +(deftest test-nested-resource-scope-with-sym + (let [native-resources (atom {}) + x (sym/ones [2 2]) + return-val (resource-scope/using + (let [temp-x (sym/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (resource-scope/using + (let [temp-y (sym/ones [3 1])] + (swap! native-resources assoc :temp-y temp-y)))))] + (is (true? (sym/is-disposed (:temp-y @native-resources)))) + (is (true? (sym/is-disposed (:temp-x @native-resources)))) + (is (false? (sym/is-disposed x))))) + +(deftest test-list-creation-with-returning-first + (let [native-resources (atom []) + return-val (resource-scope/using + (let [temp-ndarrays (doall (repeatedly 3 #(ndarray/ones [3 1]))) + _ (reset! native-resources temp-ndarrays)] + (first temp-ndarrays)))] + (is (false? (ndarray/is-disposed return-val))) + (is (= [false true true] (mapv ndarray/is-disposed @native-resources))))) + +(deftest test-list-creation + (let [native-resources (atom []) + return-val (resource-scope/using + (let [temp-ndarrays (doall (repeatedly 3 #(ndarray/ones [3 1]))) + _ (reset! native-resources temp-ndarrays)] + (ndarray/ones [3 1])))] + (is (false? (ndarray/is-disposed return-val))) + (is (= [true true true] (mapv ndarray/is-disposed @native-resources))))) + +(deftest test-list-creation-without-let + (let [native-resources (atom []) + return-val (resource-scope/using + (first (doall (repeatedly 3 #(do + (let [x (ndarray/ones [3 1])] + (swap! native-resources conj x) + x))))))] + (is (false? (ndarray/is-disposed return-val))) + (is (= [false true true] (mapv ndarray/is-disposed @native-resources))))) + +(deftest test-with-let + (let [native-resources (atom {}) + x (ndarray/ones [2 2]) + return-val (resource-scope/with-let [temp-x (ndarray/ones [3 1]) + temp-y (ndarray/ones [3 1])] + (swap! native-resources assoc :temp-x temp-x) + (swap! native-resources assoc :temp-y temp-y) + (ndarray/+ temp-x 1))] + (is (true? (ndarray/is-disposed (:temp-x @native-resources)))) + (is (true? (ndarray/is-disposed (:temp-y @native-resources)))) + (is (false? (ndarray/is-disposed return-val))) + (is (false? (ndarray/is-disposed x))) + (is (= [2.0 2.0 2.0] (ndarray/->vec return-val))))) + +(deftest test-with-do + (let [native-resources (atom {}) + x (ndarray/ones [2 2]) + return-val (resource-scope/with-do + (swap! native-resources assoc :temp-x (ndarray/ones [3 1])) + (swap! native-resources assoc :temp-y (ndarray/ones [3 1])) + (ndarray/ones [3 1]))] + (is (true? (ndarray/is-disposed (:temp-x @native-resources)))) + (is (true? (ndarray/is-disposed (:temp-y @native-resources)))) + (is (false? (ndarray/is-disposed return-val))) + (is (false? (ndarray/is-disposed x))) + (is (= [1.0 1.0 1.0] (ndarray/->vec return-val))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj index c26f83d5aa49..4ed7d38e690a 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/util_test.clj @@ -226,3 +226,10 @@ (let [nda (util/map->scala-tuple-seq {:a-b (ndarray/ones [1 2])})] (is (= "a_b" (._1 (.head nda)))) (is (= [1.0 1.0] (ndarray/->vec (._2 (.head nda))))))) + +(deftest test-forms->scala-fn + (let [scala-fn (util/forms->scala-fn + (def x 1) + (def y 2) + {:x x :y y})] + (is (= {:x 1 :y 2} (.apply scala-fn)))))