diff --git a/contrib/clojure-package/.gitignore b/contrib/clojure-package/.gitignore index f5d81ddc7620..71d812e56ecd 100644 --- a/contrib/clojure-package/.gitignore +++ b/contrib/clojure-package/.gitignore @@ -39,6 +39,8 @@ examples/visualization/test-vis.pdf src/.DS_Store src/org/.DS_Store test/test-ndarray.clj +test/test-ndarray-api.clj test/test-symbol.clj +test/test-symbol-api.clj src/org/apache/clojure_mxnet/gen/* diff --git a/contrib/clojure-package/src/dev/generator.clj b/contrib/clojure-package/src/dev/generator.clj index ca93c3421d2a..34210bef63d0 100644 --- a/contrib/clojure-package/src/dev/generator.clj +++ b/contrib/clojure-package/src/dev/generator.clj @@ -17,10 +17,14 @@ (ns dev.generator (:require [t6.from-scala.core :as scala] + [t6.from-scala.core :refer [$ $$] :as $] [clojure.reflect :as r] - [org.apache.clojure-mxnet.util :as util] - [clojure.pprint]) - (:import (org.apache.mxnet NDArray Symbol)) + [clojure.pprint] + [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet NDArray NDArrayAPI + Symbol SymbolAPI + Base Base$RefInt Base$RefLong Base$RefFloat Base$RefString) + (scala.collection.mutable ListBuffer ArrayBuffer)) (:gen-class)) @@ -34,17 +38,17 @@ (clojure.string/replace #"\_" "-") (clojure.string/replace #"\/" "div"))) -(defn symbol-transform-param-name [parameter-types] +(defn transform-param-names [coerce-fn parameter-types] (->> parameter-types (map str) - (map (fn [x] (or (util/symbol-param-coerce x) x))) + (map (fn [x] (or (coerce-fn x) x))) (map (fn [x] (last (clojure.string/split x #"\.")))))) +(defn symbol-transform-param-name [parameter-types] + (transform-param-names util/symbol-param-coerce parameter-types)) + (defn ndarray-transform-param-name [parameter-types] - (->> parameter-types - (map str) - (map (fn [x] (or (util/ndarray-param-coerce x) x))) - (map (fn [x] (last (clojure.string/split x #"\.")))))) + (transform-param-names util/ndarray-param-coerce parameter-types)) (defn has-variadic? [params] (->> params @@ -56,37 +60,136 @@ (defn increment-param-name [pname] (if-let [num-str (re-find #"-\d" pname)] - (str (first (clojure.string/split pname #"-")) "-" (inc (Integer/parseInt (last (clojure.string/split num-str #"-"))))) + (str + (first (clojure.string/split pname #"-")) + "-" + (inc (Integer/parseInt (last (clojure.string/split num-str #"-"))))) (str pname "-" 1))) -(defn rename-duplicate-params [params] - (reduce (fn [known-names n] (conj known-names (if (contains? (set known-names) n) - (increment-param-name n) - n))) - [] - params)) - +(defn rename-duplicate-params [pnames] + (->> (reduce + (fn [pname-counts n] + (let [rn (if (pname-counts n) (str n "-" (pname-counts n)) n) + inc-pname-counts (update-in pname-counts [n] (fnil inc 0))] + (update-in inc-pname-counts [:params] conj rn))) + {:params []} + pnames) + :params)) + +(defn get-public-no-default-methods [obj] + (->> (r/reflect obj) + :members + (map #(into {} %)) + (filter #(-> % :flags :public)) + (remove #(re-find #"org\$apache\$mxnet" (str (:name %)))) + (remove #(re-find #"\$default" (str (:name %)))))) + +(defn get-public-to-gen-methods [public-to-hand-gen public-no-default] + (let [public-to-hand-gen-names + (into #{} (mapv (comp str :name) public-to-hand-gen))] + (remove #(-> % :name str public-to-hand-gen-names) public-no-default))) -;;;;;;; symbol +(defn public-by-name-and-param-count [public-reflect-info] + (->> public-reflect-info + (group-by :name) + (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)])) + (into {}))) -(def symbol-reflect-info (->> (:members (r/reflect Symbol)) - (map #(into {} %)))) +(def license + (str + ";; Licensed to the Apache Software Foundation (ASF) under one or more\n" + ";; contributor license agreements. See the NOTICE file distributed with\n" + ";; this work for additional information regarding copyright ownership.\n" + ";; The ASF licenses this file to You under the Apache License, Version 2.0\n" + ";; (the \"License\"); you may not use this file except in compliance with\n" + ";; the License. You may obtain a copy of the License at\n" + ";;\n" + ";; http://www.apache.org/licenses/LICENSE-2.0\n" + ";;\n" + ";; Unless required by applicable law or agreed to in writing, software\n" + ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n" + ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" + ";; See the License for the specific language governing permissions and\n" + ";; limitations under the License.\n" + ";;\n")) -(def symbol-public (filter (fn [x] (-> x :flags :public)) symbol-reflect-info)) +(defn write-to-file [functions ns-gen fname] + (with-open [w (clojure.java.io/writer fname)] + (.write w ns-gen) + (.write w "\n\n") + (.write w ";; Do not edit - this is auto-generated") + (.write w "\n\n") + (.write w license) + (.write w "\n\n") + (.write w "\n\n") + (doseq [f functions] + (clojure.pprint/pprint f w) + (.write w "\n")))) -(def symbol-public-no-default (->> symbol-public - (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %))))) - (filter #(not (re-find #"\$default" (str (:name %))))))) +;;;;;;; Common operations + +(def libinfo (Base/_LIB)) +(def op-names + (let [l ($ ListBuffer/empty)] + (do (.mxListAllOpNames libinfo l) + (remove #(or (= "Custom" %) + (re-matches #"^_.*" %)) + (util/buffer->vec l))))) + +(defn- parse-arg-type [s] + (let [[_ var-arg-type _ set-arg-type arg-spec _ type-req _ default-val] (re-find #"(([\w-\[\]\s]+)|\{([^}]+)\})\s*(\([^)]+\))?(,\s*(optional|required)(,\s*default=(.*))?)?" s)] + {:type (clojure.string/trim (or set-arg-type var-arg-type)) + :spec arg-spec + :optional? (or (= "optional" type-req) + (= "boolean" var-arg-type)) + :default default-val + :orig s})) + +(defn- get-op-handle [op-name] + (let [ref (new Base$RefLong 0)] + (do (.nnGetOpHandle libinfo op-name ref) + (.value ref)))) + +(defn gen-op-info [op-name] + (let [handle (get-op-handle op-name) + name (new Base$RefString nil) + desc (new Base$RefString nil) + key-var-num-args (new Base$RefString nil) + num-args (new Base$RefInt 0) + arg-names ($ ListBuffer/empty) + arg-types ($ ListBuffer/empty) + arg-descs ($ ListBuffer/empty)] + (do (.mxSymbolGetAtomicSymbolInfo libinfo + handle + name + desc + num-args + arg-names + arg-types + arg-descs + key-var-num-args) + {:fn-name (clojure-case (.value name)) + :fn-description (.value desc) + :args (mapv (fn [t n d] (assoc t :name n :description d)) + (mapv parse-arg-type (util/buffer->vec arg-types)) + (mapv clojure-case (util/buffer->vec arg-names)) + (util/buffer->vec arg-descs)) + :key-var-num-args (clojure-case (.value key-var-num-args))}))) + +;;;;;;; Symbol + +(def symbol-public-no-default + (get-public-no-default-methods Symbol)) (into #{} (mapcat :parameter-types symbol-public-no-default)) - ;#{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>} +;; #{java.lang.Object scala.collection.Seq scala.Option long double scala.collection.immutable.Map int ml.dmlc.mxnet.Executor float ml.dmlc.mxnet.Context java.lang.String scala.Enumeration$Value ml.dmlc.mxnet.Symbol int<> ml.dmlc.mxnet.Symbol<> ml.dmlc.mxnet.Shape java.lang.String<>} -(def symbol-hand-gen-set #{"scala.Option" - "int org.apache.mxnet.Executor" - "scala.Enumeration$Value" - "org.apache.mxnet.Context" - "scala.Tuple2" - "scala.collection.Traversable"} ) +(def symbol-hand-gen-set + #{"scala.Option" + "scala.Enumeration$Value" + "org.apache.mxnet.Context" + "scala.Tuple2" + "scala.collection.Traversable"}) ;;; min and max have a conflicting arity of 2 with the auto gen signatures (def symbol-filter-name-set #{"max" "min"}) @@ -102,34 +205,35 @@ count pos?))) -(def symbol-public-to-hand-gen (filter is-symbol-hand-gen? symbol-public-no-default)) -(def symbol-public-to-gen (->> (remove #(contains?(->> symbol-public-to-hand-gen - (mapv :name) - (mapv str) - (set)) (str (:name %))) symbol-public-no-default))) +(def symbol-public-to-hand-gen + (filter is-symbol-hand-gen? symbol-public-no-default)) +(def symbol-public-to-gen + (get-public-to-gen-methods symbol-public-to-hand-gen + symbol-public-no-default)) (count symbol-public-to-hand-gen) ;=> 35 mostly bind! (count symbol-public-to-gen) ;=> 307 -(into #{} (map :name symbol-public-to-hand-gen));=> #{arange bind ones zeros simpleBind Variable} +(into #{} (map :name symbol-public-to-hand-gen)) +;;=> #{arange bind ones zeros simpleBind Variable} -(defn public-by-name-and-param-count [public-reflect-info] - (->> public-reflect-info - (group-by :name) - (map (fn [[k v]] [k (group-by #(count (:parameter-types %)) v)])) - (into {}))) (defn symbol-vector-args [] - `(if (map? ~'kwargs-map-or-vec-or-sym) (~'util/empty-list) (~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"}))) + `(if (map? ~'kwargs-map-or-vec-or-sym) + (~'util/empty-list) + (~'util/coerce-param ~'kwargs-map-or-vec-or-sym #{"scala.collection.Seq"}))) (defn symbol-map-args [] - `(if (map? ~'kwargs-map-or-vec-or-sym) (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym) nil)) + `(if (map? ~'kwargs-map-or-vec-or-sym) + (util/convert-symbol-map ~'kwargs-map-or-vec-or-sym) + nil)) (defn add-symbol-arities [params function-name] - (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] (mapv str params)) + (if (= ["sym-name" "kwargs-map" "symbol-list" "kwargs-map-1"] + (mapv str params)) [`([~'sym-name ~'attr-map ~'kwargs-map] (~function-name ~'sym-name (~'util/convert-symbol-map ~'attr-map) (~'util/empty-list) (~'util/convert-symbol-map ~'kwargs-map))) `([~'sym-name ~'kwargs-map-or-vec-or-sym] @@ -180,36 +284,7 @@ `(~'defn ~function-name ~@(remove nil? (gen-symbol-function-arity op-name op-values function-name)))))) -(def license - (str - ";; Licensed to the Apache Software Foundation (ASF) under one or more\n" - ";; contributor license agreements. See the NOTICE file distributed with\n" - ";; this work for additional information regarding copyright ownership.\n" - ";; The ASF licenses this file to You under the Apache License, Version 2.0\n" - ";; (the \"License\"); you may not use this file except in compliance with\n" - ";; the License. You may obtain a copy of the License at\n" - ";;\n" - ";; http://www.apache.org/licenses/LICENSE-2.0\n" - ";;\n" - ";; Unless required by applicable law or agreed to in writing, software\n" - ";; distributed under the License is distributed on an \"AS IS\" BASIS,\n" - ";; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" - ";; See the License for the specific language governing permissions and\n" - ";; limitations under the License.\n" - ";;\n")) -(defn write-to-file [functions ns-gen fname] - (with-open [w (clojure.java.io/writer fname)] - (.write w ns-gen) - (.write w "\n\n") - (.write w ";; Do not edit - this is auto-generated") - (.write w "\n\n") - (.write w license) - (.write w "\n\n") - (.write w "\n\n") - (doseq [f functions] - (clojure.pprint/pprint f w) - (.write w "\n")))) (def symbol-gen-ns "(ns org.apache.clojure-mxnet.symbol (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max @@ -223,25 +298,18 @@ (println "Generating symbol file") (write-to-file all-symbol-functions symbol-gen-ns "src/org/apache/clojure_mxnet/gen/symbol.clj")) +;;;;;;; NDArray -;;;;;;;;NDARRAY - - -(def ndarray-reflect-info (->> (:members (r/reflect NDArray)) - (map #(into {} %)))) +(def ndarray-public-no-default + (get-public-no-default-methods NDArray)) -(def ndarray-public (filter (fn [x] (-> x :flags :public)) ndarray-reflect-info)) - -(def ndarray-public-no-default (->> ndarray-public - (filter #(not (re-find #"org\$apache\$mxnet" (str (:name %))))) - (filter #(not (re-find #"\$default" (str (:name %))))))) - -(def ndarray-hand-gen-set #{"org.apache.mxnet.NDArrayFuncReturn" - "org.apache.mxnet.Context" - "scala.Enumeration$Value" - "scala.Tuple2" - "scala.collection.Traversable"} ) +(def ndarray-hand-gen-set + #{"org.apache.mxnet.NDArrayFuncReturn" + "org.apache.mxnet.Context" + "scala.Enumeration$Value" + "scala.Tuple2" + "scala.collection.Traversable"}) (defn is-ndarray-hand-gen? [info] (->> (map str (:parameter-types info)) @@ -251,17 +319,17 @@ pos?)) -(def ndarray-public-to-hand-gen (filter is-ndarray-hand-gen? ndarray-public-no-default)) -(def ndarray-public-to-gen (->> (remove #(contains?(->> ndarray-public-to-hand-gen - (mapv :name) - (mapv str) - (set)) (str (:name %))) ndarray-public-no-default))) +(def ndarray-public-to-hand-gen + (filter is-ndarray-hand-gen? ndarray-public-no-default)) +(def ndarray-public-to-gen + (get-public-to-gen-methods ndarray-public-to-hand-gen + ndarray-public-no-default)) (count ndarray-public-to-hand-gen) ;=> 15 (count ndarray-public-to-gen) ;=> 486 -(map :name ndarray-public-to-hand-gen) +(->> ndarray-public-to-hand-gen (map :name) (into #{})) @@ -294,16 +362,19 @@ ))))) +(defn gen-ndarray-functions [public-to-gen-methods] + (for [operation (sort (public-by-name-and-param-count public-to-gen-methods))] + (let [[op-name op-values] operation + function-name (-> op-name + str + scala/decode-scala-symbol + clojure-case + symbol)] + `(~'defn ~function-name + ~@(remove nil? (gen-ndarray-function-arity op-name op-values)))))) + (def all-ndarray-functions - (for [operation (sort (public-by-name-and-param-count ndarray-public-to-gen))] - (let [[op-name op-values] operation - function-name (-> op-name - str - scala/decode-scala-symbol - clojure-case - symbol)] - `(~'defn ~function-name - ~@(remove nil? (gen-ndarray-function-arity op-name op-values)))))) + (gen-ndarray-functions ndarray-public-to-gen)) (def ndarray-gen-ns "(ns org.apache.clojure-mxnet.ndarray (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max @@ -314,16 +385,191 @@ (defn generate-ndarray-file [] (println "Generating ndarray file") - (write-to-file all-ndarray-functions ndarray-gen-ns "src/org/apache/clojure_mxnet/gen/ndarray.clj")) + (write-to-file all-ndarray-functions + ndarray-gen-ns + "src/org/apache/clojure_mxnet/gen/ndarray.clj")) + +;;;;;;; SymbolAPI + +(defn symbol-api-coerce-param + [{:keys [name sym type optional?]}] + (let [coerced-param (case type + "Shape" `(when ~sym (~'mx-shape/->shape ~sym)) + "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym) + "Map[String, String]" + `(when ~sym + (->> ~sym + (mapv (fn [[~'k ~'v]] [~'k (str ~'v)])) + (into {}) + ~'util/convert-map)) + sym) + nil-param-allowed? (#{"name" "attr"} name)] + (if (and optional? (not nil-param-allowed?)) + `(~'util/->option ~coerced-param) + coerced-param))) + +(defn gen-symbol-api-doc [fn-description params] + (let [param-descriptions (mapv (fn [{:keys [name description optional?]}] + (str "`" name "`: " + description + (when optional? " (optional)") + "\n")) + params)] + (str fn-description "\n\n" + (apply str param-descriptions)))) + +(defn gen-symbol-api-default-arity [op-name params] + (let [opt-params (filter :optional? params) + coerced-params (mapv symbol-api-coerce-param params) + default-args (array-map :keys (mapv :sym params) + :or (into {} + (mapv (fn [{:keys [sym]}] [sym nil]) + opt-params)) + :as 'opts)] + `([~default-args] + (~'util/coerce-return + (~(symbol (str "SymbolAPI/" op-name)) + ~@coerced-params))))) + +(defn gen-symbol-api-function [op-name] + (let [{:keys [fn-name fn-description args]} (gen-op-info op-name) + params (mapv (fn [{:keys [name type optional?] :as opts}] + (assoc opts + :sym (symbol name) + :optional? (or optional? + (= "NDArray-or-Symbol" type)))) + (conj args + {:name "name" + :type "String" + :optional? true + :description "Name of the symbol"} + {:name "attr" + :type "Map[String, String]" + :optional? true + :description "Attributes of the symbol"})) + doc (gen-symbol-api-doc fn-description params) + default-call (gen-symbol-api-default-arity op-name params)] + `(~'defn ~(symbol fn-name) + ~doc + ~@default-call))) + +(def all-symbol-api-functions + (mapv gen-symbol-api-function op-names)) + +(def symbol-api-gen-ns "(ns + ^{:doc \"Experimental\"} + org.apache.clojure-mxnet.symbol-api + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle ref]) + (:require [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.shape :as mx-shape]) + (:import (org.apache.mxnet SymbolAPI)))") + +(defn generate-symbol-api-file [] + (println "Generating symbol-api file") + (write-to-file all-symbol-api-functions symbol-api-gen-ns "src/org/apache/clojure_mxnet/gen/symbol_api.clj")) + +;;;;;;; NDArrayAPI + +(defn ndarray-api-coerce-param + [{:keys [sym type optional?]}] + (let [coerced-param (case type + "Shape" `(when ~sym (~'mx-shape/->shape ~sym)) + "NDArray-or-Symbol[]" `(~'clojure.core/into-array ~sym) + sym)] + (if optional? + `(~'util/->option ~coerced-param) + coerced-param))) + +(defn gen-ndarray-api-doc [fn-description params] + (let [param-descriptions (mapv (fn [{:keys [name description optional?]}] + (str "`" name "`: " + description + (when optional? " (optional)") + "\n")) + params)] + (str fn-description "\n\n" + (apply str param-descriptions)))) + +(defn gen-ndarray-api-default-arity [op-name params] + (let [opt-params (filter :optional? params) + coerced-params (mapv ndarray-api-coerce-param params) + default-args (array-map :keys (mapv :sym params) + :or (into {} + (mapv (fn [{:keys [sym]}] [sym nil]) + opt-params)) + :as 'opts)] + `([~default-args] + (~'util/coerce-return + (~(symbol (str "NDArrayAPI/" op-name)) + ~@coerced-params))))) + +(defn gen-ndarray-api-required-arity [fn-name req-params] + (let [req-args (->> req-params + (mapv (fn [{:keys [sym]}] [(keyword sym) sym])) + (into {}))] + `(~(mapv :sym req-params) + (~(symbol fn-name) ~req-args)))) + +(defn gen-ndarray-api-function [op-name] + (let [{:keys [fn-name fn-description args]} (gen-op-info op-name) + params (mapv (fn [{:keys [name] :as opts}] + (assoc opts :sym (symbol name))) + (conj args {:name "out" + :type "NDArray-or-Symbol" + :optional? true + :description "Output array."})) + doc (gen-ndarray-api-doc fn-description params) + opt-params (filter :optional? params) + req-params (remove :optional? params) + req-call (gen-ndarray-api-required-arity fn-name req-params) + default-call (gen-ndarray-api-default-arity op-name params)] + (if (= 1 (count req-params)) + `(~'defn ~(symbol fn-name) + ~doc + ~@default-call) + `(~'defn ~(symbol fn-name) + ~doc + ~req-call + ~default-call)))) + +(def all-ndarray-api-functions + (mapv gen-ndarray-api-function op-names)) + +(def ndarray-api-gen-ns "(ns + ^{:doc \"Experimental\"} + org.apache.clojure-mxnet.ndarray-api + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle + ref]) + (:require [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet NDArrayAPI)))") + + +(defn generate-ndarray-api-file [] + (println "Generating ndarray-api file") + (write-to-file all-ndarray-api-functions + ndarray-api-gen-ns + "src/org/apache/clojure_mxnet/gen/ndarray_api.clj")) ;;; autogen the files (do (generate-ndarray-file) - (generate-symbol-file)) + (generate-ndarray-api-file) + (generate-symbol-file) + (generate-symbol-api-file)) (comment + (gen-op-info "ElementWiseSum") + + (gen-ndarray-api-function "Activation") + + (gen-symbol-api-function "Activation") + ;; This generates a file with the bulk of the nd-array functions (generate-ndarray-file) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj new file mode 100644 index 000000000000..70359a6ef9b7 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray_api.clj @@ -0,0 +1,32 @@ +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.ndarray-api + "Experimental NDArray API" + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle + ref]) + + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as mx-context] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [clojure.reflect :as r] + [t6.from-scala.core :refer [$] :as $]) + (:import (org.apache.mxnet NDArrayAPI))) + +;; loads the generated functions into the namespace +(do (clojure.core/load "gen/ndarray_api")) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj new file mode 100644 index 000000000000..69cc8136d500 --- /dev/null +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/symbol_api.clj @@ -0,0 +1,32 @@ +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.symbol-api + "Experimental Symbol API" + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle ref]) + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as mx-context] + [org.apache.clojure-mxnet.executor :as ex] + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util] + [t6.from-scala.core :refer [$] :as $] + [org.apache.clojure-mxnet.ndarray :as ndarray]) + (:import (org.apache.mxnet SymbolAPI))) + +;; loads the generated functions into the namespace +(do (clojure.core/load "gen/symbol_api")) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj index 89ac1cd66a57..7ee25d4dd25e 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/util.clj @@ -35,7 +35,6 @@ "int<>" "vec-of-ints" "float<>" "vec-of-floats" "byte<>" "byte-array" - "java.lang.String<>" "vec-or-strings" "org.apache.mxnet.NDArray" "ndarray" "org.apache.mxnet.Symbol" "sym" "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE" "double-or-float"}) @@ -49,7 +48,7 @@ "int<>" "vec-of-ints" "float<>" "vec-of-floats" "byte<>" "byte-array" - "java.lang.String<>" "vec-or-strings" + "java.lang.String<>" "vec-of-strings" "org.apache.mxnet.Symbol" "sym" "java.lang.Object" "object"}) @@ -152,9 +151,12 @@ (and (get targets "scala.collection.Seq") (instance? org.apache.mxnet.Symbol param)) ($/immutable-list param) (and (get targets "scala.collection.Seq") (and (or (vector? param) (seq? param)) (empty? param))) (empty-list) (and (get targets "scala.collection.Seq") (or (vector? param) (seq? param))) (apply $/immutable-list param) + (and (get targets "org.apache.mxnet.Shape") (or (vector? param) (seq? param) (empty? param))) (mx-shape/->shape param) (and (get targets "int<>") (vector? param)) (int-array param) (and (get targets "float<>") (vector? param)) (float-array param) (and (get targets "java.lang.String<>") (vector? param)) (into-array param) + (and (get targets "org.apache.mxnet.NDArray<>") (vector? param)) (into-array param) + (and (get targets "org.apache.mxnet.Symbol<>") (vector? param)) (into-array param) (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (instance? Float param)) (primitives/mx-float param) (and (get targets "org.apache.mxnet.MX_PRIMITIVES$MX_PRIMITIVE_TYPE") (number? param)) (primitives/mx-double param) :else param)) diff --git a/contrib/clojure-package/test/dev/generator_test.clj b/contrib/clojure-package/test/dev/generator_test.clj index 05b4a741bc7c..cf28241c59e8 100644 --- a/contrib/clojure-package/test/dev/generator_test.clj +++ b/contrib/clojure-package/test/dev/generator_test.clj @@ -50,6 +50,127 @@ (is (= transformed-params (gen/symbol-transform-param-name (:parameter-types (symbol-reflect-info "floor"))))))) +(deftest test-gen-op-info + (testing "activation" + (let [activation-info (gen/gen-op-info "Activation")] + (is (= "activation" (:fn-name activation-info))) + (is (string? (:fn-description activation-info))) + (is (= 2 (-> activation-info :args count))) + (is (= "" (:key-var-num-args activation-info))) + + (is (= "data" (-> activation-info :args first :name))) + (is (= "NDArray-or-Symbol" (-> activation-info :args first :type))) + (is (false? (-> activation-info :args first :optional?))) + (is (nil? (-> activation-info :args first :default))) + (is (string? (-> activation-info :args first :description))) + + (is (= "act-type" (-> activation-info :args second :name))) + (is (= "'relu', 'sigmoid', 'softrelu', 'softsign', 'tanh'" (-> activation-info :args second :type))) + (is (false? (-> activation-info :args second :optional?))) + (is (nil? (-> activation-info :args second :default))) + (is (string? (-> activation-info :args second :description))))) + + (testing "argmin" + (let [argmin-info (gen/gen-op-info "argmin")] + (is (= "argmin" (:fn-name argmin-info))) + (is (= 3 (-> argmin-info :args count))) + + (is (= "data" (-> argmin-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol" (-> argmin-info :args (nth 0) :type))) + (is (false? (-> argmin-info :args (nth 0) :optional?))) + + (is (= "axis" (-> argmin-info :args (nth 1) :name))) + (is (= "int or None" (-> argmin-info :args (nth 1) :type))) + (is (= "'None'" (-> argmin-info :args (nth 1) :default))) + (is (true? (-> argmin-info :args (nth 1) :optional?))) + + (is (= "keepdims" (-> argmin-info :args (nth 2) :name))) + (is (= "boolean" (-> argmin-info :args (nth 2) :type))) + (is (= "0" (-> argmin-info :args (nth 2) :default))) + (is (true? (-> argmin-info :args (nth 2) :optional?))))) + + (testing "concat" + (let [concat-info (gen/gen-op-info "Concat")] + (is (= "concat" (:fn-name concat-info))) + (is (= 3 (-> concat-info :args count))) + (is (= "num-args" (:key-var-num-args concat-info))) + + (is (= "data" (-> concat-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol[]" (-> concat-info :args (nth 0) :type))) + (is (false? (-> concat-info :args (nth 0) :optional?))) + + (is (= "num-args" (-> concat-info :args (nth 1) :name))) + (is (= "int" (-> concat-info :args (nth 1) :type))) + (is (false? (-> concat-info :args (nth 1) :optional?))) + + (is (= "dim" (-> concat-info :args (nth 2) :name))) + (is (= "int" (-> concat-info :args (nth 2) :type))) + (is (= "'1'" (-> concat-info :args (nth 2) :default))) + (is (true? (-> concat-info :args (nth 2) :optional?))))) + + (testing "convolution" + (let [convolution-info (gen/gen-op-info "Convolution")] + + (is (= "convolution" (:fn-name convolution-info))) + (is (= 14 (-> convolution-info :args count))) + (is (= "" (:key-var-num-args convolution-info))) + + (is (= "data" (-> convolution-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 0) :type))) + (is (false? (-> convolution-info :args (nth 0) :optional?))) + + (is (= "weight" (-> convolution-info :args (nth 1) :name))) + (is (= "NDArray-or-Symbol" (-> convolution-info :args (nth 1) :type))) + (is (false? (-> convolution-info :args (nth 1) :optional?))) + + (is (= "kernel" (-> convolution-info :args (nth 3) :name))) + (is (= "Shape" (-> convolution-info :args (nth 3) :type))) + (is (= "(tuple)" (-> convolution-info :args (nth 3) :spec))) + (is (false? (-> convolution-info :args (nth 3) :optional?))) + + (is (= "stride" (-> convolution-info :args (nth 4) :name))) + (is (= "Shape" (-> convolution-info :args (nth 4) :type))) + (is (= "(tuple)" (-> convolution-info :args (nth 4) :spec))) + (is (= "[]" (-> convolution-info :args (nth 4) :default))) + (is (true? (-> convolution-info :args (nth 4) :optional?))) + + (is (= "num-filter" (-> convolution-info :args (nth 7) :name))) + (is (= "int" (-> convolution-info :args (nth 7) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 7) :spec))) + (is (false? (-> convolution-info :args (nth 7) :optional?))) + + (is (= "num-group" (-> convolution-info :args (nth 8) :name))) + (is (= "int" (-> convolution-info :args (nth 8) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 8) :spec))) + (is (= "1" (-> convolution-info :args (nth 8) :default))) + (is (true? (-> convolution-info :args (nth 8) :optional?))) + + (is (= "workspace" (-> convolution-info :args (nth 9) :name))) + (is (= "long" (-> convolution-info :args (nth 9) :type))) + (is (= "(non-negative)" (-> convolution-info :args (nth 9) :spec))) + (is (= "1024" (-> convolution-info :args (nth 9) :default))) + (is (true? (-> convolution-info :args (nth 9) :optional?))) + + (is (= "no-bias" (-> convolution-info :args (nth 10) :name))) + (is (= "boolean" (-> convolution-info :args (nth 10) :type))) + (is (= "0" (-> convolution-info :args (nth 10) :default))) + (is (true? (-> convolution-info :args (nth 10) :optional?))) + + (is (= "layout" (-> convolution-info :args (nth 13) :name))) + (is (= "None, 'NCDHW', 'NCHW', 'NCW', 'NDHWC', 'NHWC'" (-> convolution-info :args (nth 13) :type))) + (is (= "'None'" (-> convolution-info :args (nth 13) :default))) + (is (true? (-> convolution-info :args (nth 13) :optional?))))) + + (testing "element wise sum" + (let [element-wise-sum-info (gen/gen-op-info "ElementWiseSum")] + (is (= "add-n" (:fn-name element-wise-sum-info))) + (is (= 1 (-> element-wise-sum-info :args count))) + (is (= "num-args" (:key-var-num-args element-wise-sum-info))) + + (is (= "args" (-> element-wise-sum-info :args (nth 0) :name))) + (is (= "NDArray-or-Symbol[]" (-> element-wise-sum-info :args (nth 0) :type))) + (is (false? (-> element-wise-sum-info :args (nth 0) :optional?)))))) + (deftest test-ndarray-transform-param-name (let [params ["scala.collection.immutable.Map" "scala.collection.Seq"] @@ -68,7 +189,10 @@ (deftest test-rename-duplicate-params (is (= ["foo" "bar" "baz"] (gen/rename-duplicate-params ["foo" "bar" "baz"]))) - (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" "bar"])))) + (is (= ["foo" "bar" "bar-1"] (gen/rename-duplicate-params ["foo" "bar" "bar"]))) + (is (= ["foo" "bar" "bar-1" "foo-1"] (gen/rename-duplicate-params ["foo" "bar" "bar" "foo"]))) + (is (= ["foo" "bar" "bar-1" "bar-2"] (gen/rename-duplicate-params ["foo" "bar" "bar" "bar"]))) + (is (= ["foo" "bar" "bar-1" "bar-2" "foo-1" "baz"] (gen/rename-duplicate-params ["foo" "bar" "bar" "bar" "foo" "baz"])))) (deftest test-is-symbol-hand-gen? (is (not (false? (gen/is-symbol-hand-gen? (symbol-reflect-info "max"))))) @@ -191,7 +315,17 @@ (gen/gen-ndarray-function-arity op-name op-values))))) (deftest test-write-to-file - (testing "symbol" + (testing "symbol-api" + (let [fname "test/test-symbol-api.clj" + _ (gen/write-to-file [(first gen/all-symbol-api-functions) + (second gen/all-symbol-api-functions)] + gen/symbol-api-gen-ns + fname) + good-contents (slurp "test/good-test-symbol-api.clj") + contents (slurp fname)] + (is (= good-contents contents)))) + + (testing "symbol" (let [fname "test/test-symbol.clj" _ (gen/write-to-file [(first gen/all-symbol-functions)] gen/symbol-gen-ns @@ -200,6 +334,16 @@ contents (slurp fname)] (is (= good-contents contents)))) + (testing "ndarray-api" + (let [fname "test/test-ndarray-api.clj" + _ (gen/write-to-file [(first gen/all-ndarray-api-functions) + (second gen/all-ndarray-api-functions)] + gen/ndarray-api-gen-ns + fname) + good-contents (slurp "test/good-test-ndarray-api.clj") + contents (slurp fname)] + (is (= good-contents contents)))) + (testing "ndarray" (let [fname "test/test-ndarray.clj" _ (gen/write-to-file [(first gen/all-ndarray-functions)] diff --git a/contrib/clojure-package/test/good-test-ndarray-api.clj b/contrib/clojure-package/test/good-test-ndarray-api.clj new file mode 100644 index 000000000000..1b83a7beb7bc --- /dev/null +++ b/contrib/clojure-package/test/good-test-ndarray-api.clj @@ -0,0 +1,89 @@ +(ns + ^{:doc "Experimental"} + org.apache.clojure-mxnet.ndarray-api + (:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max + min repeat reverse set sort take to-array empty shuffle + ref]) + (:require [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet NDArrayAPI))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + activation + "Applies an activation function element-wise to the input.\n\nThe following activation functions are supported:\n\n- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n- `tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n- `softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in src/operator/nn/activation.cc:L167\n\n`data`: The input array.\n`act-type`: Activation function to be applied.\n`out`: Output array. (optional)\n" + ([data act-type] (activation {:data data, :act-type act-type})) + ([{:keys [data act-type out], :or {out nil}, :as opts}] + (util/coerce-return + (NDArrayAPI/Activation data act-type (util/->option out))))) + +(defn + batch-norm + "Batch normalization.\n\nNormalizes a data batch by mean and variance, and applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has more than one dimension and we normalize along axis 1.\nWe first compute the mean and variance along this axis:\n\n.. math::\n\n data\\_mean[i] = mean(data[:,i,:,...]) \\\\\n data\\_var[i] = var(data[:,i,:,...])\n\nThen compute the normalized output, which has the same shape as input, as following:\n\n.. math::\n\n out[:,i,:,...] = \\frac{data[:,i,:,...] - data\\_mean[i]}{\\sqrt{data\\_var[i]+\\epsilon}} * gamma[i] + beta[i]\n\nBoth *mean* and *var* returns a scalar by treating the input as a vector.\n\nAssume the input has size *k* on axis 1, then both ``gamma`` and ``beta``\nhave shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and\nthe inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these\ntwo outputs are blocked.\n\nBesides the inputs and the outputs, this operator accepts two auxiliary\nstates, ``moving_mean`` and ``moving_var``, which are *k*-length\nvectors. They are global statistics for the whole dataset, which are updated\nby::\n\n moving_mean = moving_mean * momentum + data_mean * (1 - momentum)\n moving_var = moving_var * momentum + data_var * (1 - momentum)\n\nIf ``use_global_stats`` is set to be true, then ``moving_mean`` and\n``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute\nthe output. It is often used during inference.\n\nThe parameter ``axis`` specifies which axis of the input shape denotes\nthe 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel\naxis to be the last item in the input shape.\n\nBoth ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,\nthen set ``gamma`` to 1 and its gradient to 0.\n\n.. Note::\n When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma is`` set to False,\n the sparse tensors will fallback.\n\n\n\nDefined in src/operator/nn/batch_norm.cc:L574\n\n`data`: Input data to batch normalization\n`gamma`: gamma array\n`beta`: beta array\n`moving-mean`: running mean of input\n`moving-var`: running variance of input\n`eps`: Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) (optional)\n`momentum`: Momentum for moving average (optional)\n`fix-gamma`: Fix gamma while training (optional)\n`use-global-stats`: Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. (optional)\n`output-mean-var`: Output the mean and inverse std (optional)\n`axis`: Specify which shape axis the channel is specified (optional)\n`cudnn-off`: Do not select CUDNN operator, if available (optional)\n`out`: Output array. (optional)\n" + ([data gamma beta moving-mean moving-var] + (batch-norm + {:data data, + :gamma gamma, + :beta beta, + :moving-mean moving-mean, + :moving-var moving-var})) + ([{:keys + [data + gamma + beta + moving-mean + moving-var + eps + momentum + fix-gamma + use-global-stats + output-mean-var + axis + cudnn-off + out], + :or + {eps nil, + momentum nil, + fix-gamma nil, + use-global-stats nil, + output-mean-var nil, + axis nil, + cudnn-off nil, + out nil}, + :as opts}] + (util/coerce-return + (NDArrayAPI/BatchNorm + data + gamma + beta + moving-mean + moving-var + (util/->option eps) + (util/->option momentum) + (util/->option fix-gamma) + (util/->option use-global-stats) + (util/->option output-mean-var) + (util/->option axis) + (util/->option cudnn-off) + (util/->option out))))) + diff --git a/contrib/clojure-package/test/good-test-symbol-api.clj b/contrib/clojure-package/test/good-test-symbol-api.clj new file mode 100644 index 000000000000..a03088486ee8 --- /dev/null +++ b/contrib/clojure-package/test/good-test-symbol-api.clj @@ -0,0 +1,109 @@ +(ns + ^{:doc "Experimental"} + org.apache.clojure-mxnet.symbol-api + (:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max + min repeat reverse set sort take to-array empty sin + get apply shuffle ref]) + (:require [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.shape :as mx-shape]) + (:import (org.apache.mxnet SymbolAPI))) + +;; Do not edit - this is auto-generated + +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + + + + +(defn + activation + "Applies an activation function element-wise to the input.\n\nThe following activation functions are supported:\n\n- `relu`: Rectified Linear Unit, :math:`y = max(x, 0)`\n- `sigmoid`: :math:`y = \\frac{1}{1 + exp(-x)}`\n- `tanh`: Hyperbolic tangent, :math:`y = \\frac{exp(x) - exp(-x)}{exp(x) + exp(-x)}`\n- `softrelu`: Soft ReLU, or SoftPlus, :math:`y = log(1 + exp(x))`\n- `softsign`: :math:`y = \\frac{x}{1 + abs(x)}`\n\n\n\nDefined in src/operator/nn/activation.cc:L167\n\n`data`: The input array. (optional)\n`act-type`: Activation function to be applied.\n`name`: Name of the symbol (optional)\n`attr`: Attributes of the symbol (optional)\n" + [{:keys [data act-type name attr], + :or {data nil, name nil, attr nil}, + :as opts}] + (util/coerce-return + (SymbolAPI/Activation + (util/->option data) + act-type + name + (clojure.core/when + attr + (clojure.core/->> + attr + (clojure.core/mapv + (clojure.core/fn [[k v]] [k (clojure.core/str v)])) + (clojure.core/into {}) + util/convert-map))))) + +(defn + batch-norm + "Batch normalization.\n\nNormalizes a data batch by mean and variance, and applies a scale ``gamma`` as\nwell as offset ``beta``.\n\nAssume the input has more than one dimension and we normalize along axis 1.\nWe first compute the mean and variance along this axis:\n\n.. math::\n\n data\\_mean[i] = mean(data[:,i,:,...]) \\\\\n data\\_var[i] = var(data[:,i,:,...])\n\nThen compute the normalized output, which has the same shape as input, as following:\n\n.. math::\n\n out[:,i,:,...] = \\frac{data[:,i,:,...] - data\\_mean[i]}{\\sqrt{data\\_var[i]+\\epsilon}} * gamma[i] + beta[i]\n\nBoth *mean* and *var* returns a scalar by treating the input as a vector.\n\nAssume the input has size *k* on axis 1, then both ``gamma`` and ``beta``\nhave shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and\nthe inverse of ``data_var``, which are needed for the backward pass. Note that gradient of these\ntwo outputs are blocked.\n\nBesides the inputs and the outputs, this operator accepts two auxiliary\nstates, ``moving_mean`` and ``moving_var``, which are *k*-length\nvectors. They are global statistics for the whole dataset, which are updated\nby::\n\n moving_mean = moving_mean * momentum + data_mean * (1 - momentum)\n moving_var = moving_var * momentum + data_var * (1 - momentum)\n\nIf ``use_global_stats`` is set to be true, then ``moving_mean`` and\n``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute\nthe output. It is often used during inference.\n\nThe parameter ``axis`` specifies which axis of the input shape denotes\nthe 'channel' (separately normalized groups). The default is 1. Specifying -1 sets the channel\naxis to be the last item in the input shape.\n\nBoth ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,\nthen set ``gamma`` to 1 and its gradient to 0.\n\n.. Note::\n When ``fix_gamma`` is set to True, no sparse support is provided. If ``fix_gamma is`` set to False,\n the sparse tensors will fallback.\n\n\n\nDefined in src/operator/nn/batch_norm.cc:L574\n\n`data`: Input data to batch normalization (optional)\n`gamma`: gamma array (optional)\n`beta`: beta array (optional)\n`moving-mean`: running mean of input (optional)\n`moving-var`: running variance of input (optional)\n`eps`: Epsilon to prevent div 0. Must be no less than CUDNN_BN_MIN_EPSILON defined in cudnn.h when using cudnn (usually 1e-5) (optional)\n`momentum`: Momentum for moving average (optional)\n`fix-gamma`: Fix gamma while training (optional)\n`use-global-stats`: Whether use global moving statistics instead of local batch-norm. This will force change batch-norm into a scale shift operator. (optional)\n`output-mean-var`: Output the mean and inverse std (optional)\n`axis`: Specify which shape axis the channel is specified (optional)\n`cudnn-off`: Do not select CUDNN operator, if available (optional)\n`name`: Name of the symbol (optional)\n`attr`: Attributes of the symbol (optional)\n" + [{:keys + [data + gamma + beta + moving-mean + moving-var + eps + momentum + fix-gamma + use-global-stats + output-mean-var + axis + cudnn-off + name + attr], + :or + {output-mean-var nil, + axis nil, + cudnn-off nil, + fix-gamma nil, + eps nil, + data nil, + attr nil, + beta nil, + name nil, + use-global-stats nil, + moving-mean nil, + moving-var nil, + momentum nil, + gamma nil}, + :as opts}] + (util/coerce-return + (SymbolAPI/BatchNorm + (util/->option data) + (util/->option gamma) + (util/->option beta) + (util/->option moving-mean) + (util/->option moving-var) + (util/->option eps) + (util/->option momentum) + (util/->option fix-gamma) + (util/->option use-global-stats) + (util/->option output-mean-var) + (util/->option axis) + (util/->option cudnn-off) + name + (clojure.core/when + attr + (clojure.core/->> + attr + (clojure.core/mapv + (clojure.core/fn [[k v]] [k (clojure.core/str v)])) + (clojure.core/into {}) + util/convert-map))))) + diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj index feda45b9d027..ca9d4bc93986 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/conv_test.clj @@ -24,6 +24,8 @@ [org.apache.clojure-mxnet.module :as m] [org.apache.clojure-mxnet.optimizer :as optimizer] [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.symbol-api :as sym-api] + [org.apache.clojure-mxnet.util :as util] [clojure.reflect :as r])) (def data-dir "data/") @@ -54,17 +56,19 @@ (defn get-symbol [] (as-> (sym/variable "data") data - (sym/convolution "conv1" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) - (sym/batch-norm "bn1" {:data data}) - (sym/activation "relu1" {:data data :act-type "relu"}) - (sym/pooling "mp1" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) (sym/convolution "conv2" {:data data :kernel [3 3] :num-filter 32 :stride [2 2]}) - (sym/batch-norm "bn2" {:data data}) - (sym/activation "relu2" {:data data :act-type "relu"}) - (sym/pooling "mp2" {:data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + (sym-api/convolution {:name "conv1" :data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + (sym-api/batch-norm {:name "bn1" :data data}) + (sym-api/activation {:name "relu1" :data data :act-type "relu"}) + (sym-api/pooling {:name "mp1" :data data :kernel [2 2] :pool-type "max" :stride [2 2]}) - (sym/flatten "fl" {:data data}) - (sym/fully-connected "fc2" {:data data :num-hidden 10}) - (sym/softmax-output "softmax" {:data data}))) + (sym-api/convolution {:name "conv2" :data data :kernel [3 3] :num-filter 32 :stride [2 2]}) + (sym-api/batch-norm {:name "bn2" :data data}) + (sym-api/activation {:name "relu2" :data data :act-type "relu"}) + (sym-api/pooling {:name "mp2" :data data :kernel [2 2] :pool-type "max" :stride [2 2]}) + + (sym-api/flatten {:name "fl" :data data}) + (sym-api/fully-connected {:name "fc2" :data data :num-hidden 10}) + (sym-api/softmax-output {:name "softmax" :data data}))) (deftest test-conv [] (let [mod (m/module (get-symbol))] diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj new file mode 100644 index 000000000000..18b8b78f19d1 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/ndarray_api_test.clj @@ -0,0 +1,415 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.ndarray-api-test + (:require [org.apache.clojure-mxnet.base :as base] + [org.apache.clojure-mxnet.context :as ctx] + [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.ndarray :as ndarray :refer [->vec zeros ones += -= *= full shape shape-vec]] + [org.apache.clojure-mxnet.ndarray-api :as ndarray-api] + [org.apache.clojure-mxnet.shape :as mx-shape :refer [->shape]] + [org.apache.clojure-mxnet.test-util :as test-util :refer [approx=]] + [org.apache.clojure-mxnet.util :as util :refer [->option]] + [clojure.test :refer :all])) + +(deftest test-activation + (let [data (ndarray/array [2 1 0 -1 -2] [1 5]) + relu (ndarray-api/activation data "relu") + sigmoid (ndarray-api/activation data "sigmoid") + softsign (ndarray-api/activation data "softsign") + out (ndarray/zeros [1 5]) + _ (ndarray-api/activation {:data data :act-type "relu" :out out})] + (is (= [2.0 1.0 0.0 0.0 0.0] (->vec relu))) + (is (approx= 1e-3 [0.881 0.731 0.5 0.269 0.119] (->vec sigmoid))) + (is (approx= 1e-3 [0.666 0.5 0.0 -0.5 -0.666] (->vec softsign))) + (is (= [2.0 1.0 0.0 0.0 0.0] (->vec out))))) + +(deftest test-bilinear-sampler + (let [data (ndarray/array [1 4 3 6 + 1 8 8 9 + 0 4 1 5 + 1 0 1 3] + [1 1 4 4]) + affine (ndarray/array [2 0 0 + 0 2 0] + [1 6]) + grid (ndarray-api/grid-generator {:data affine :transform-type "affine" :target-shape [4 4]}) + out (ndarray-api/bilinear-sampler data grid)] + (is (approx= 1e-3 + [0.0 0.0 0.0 0.0 + 0.0 3.5 6.5 0.0 + 0.0 1.25 2.5 0.0 + 0.0 0.0 0.0 0.0] + (->vec out))))) + +(deftest test-cast + (let [nda1 (ndarray/array [0.9 1.3] [2]) + nda2 (ndarray/array [1e20 11.1] [2]) + nda3 (ndarray/array [300 11.1 10.9 -1 -3] [5]) + out (ndarray/zeros [2] {:dtype dtype/INT32}) + _ (ndarray-api/cast {:data nda1 :dtype (str dtype/INT32) :out out})] + (is (= [0.0 1.0] (->vec (ndarray-api/cast nda1 (str dtype/INT32))))) + (is (= [(float 1e20) (float 11.1)] (->vec (ndarray-api/cast nda2 (str dtype/FLOAT32))))) + ;; uint8 gets converted to native types after ->vec + (is (= [44.0 11.0 10.0 -1.0 -3.0] (->vec (ndarray-api/cast nda3 "uint8")))))) + +(deftest test-concat + (let [nda1 (ndarray/zeros [1 2]) + nda2 (ndarray/ones [1 2]) + out (ndarray/zeros [1 4]) + res1 (ndarray-api/concat [nda1 nda2] 2) ;; num_args=2, dim=1 (default) + res2 (ndarray-api/concat {:data [nda1 nda2] :num-args 2 :dim 0}) ;; num_args=2, dim=0 + res3 (ndarray-api/concat {:data [nda1 nda2 nda1] :num-args 3 :dim 1}) ;; num_args=3, dim=1 + _ (ndarray-api/concat {:data [nda1 nda2] :num-args 2 :dim 1 :out out}) ;; store result in out + ] + (is (= [0.0 0.0 1.0 1.0] (->vec res1))) + (is (= [1 4] (shape-vec res1))) + (is (= [0.0 0.0 1.0 1.0] (->vec res2))) + (is (= [2 2] (shape-vec res2))) + (is (= [0.0 0.0 1.0 1.0 0.0 0.0] (->vec res3))) + (is (= [1 6] (shape-vec res3))) + (is (= [0.0 0.0 1.0 1.0] (->vec out))) + (is (= [1 4] (shape-vec out))))) + +(deftest test-embedding + (let [input-dim 4 + output-dim 5 + w (ndarray/array [0. 1. 2. 3. 4. + 5. 6. 7. 8. 9. + 10. 11. 12. 13. 14. + 15. 16. 17. 18. 19.] + [4 5]) + x (ndarray/array [1. 3. + 0. 2.] + [2 2]) + out (ndarray-api/embedding x w input-dim output-dim)] + (is (= [5. 6. 7. 8. 9. + 15. 16. 17. 18. 19. + 0. 1. 2. 3. 4. + 10. 11. 12. 13. 14.] + (->vec out))) + (is (= [2 2 5] (shape-vec out))))) + +(deftest test-flatten + (let [nda (ndarray/array [1 2 3 + 4 5 6 + 7 8 9 + 1 2 3 + 4 5 6 + 7 8 9] + [2 3 3]) + out (ndarray/zeros [2 9]) + res (ndarray-api/flatten {:data nda}) + _ (ndarray-api/flatten {:data nda :out out})] + (is (= [1. 2. 3. 4. 5. 6. 7. 8. 9. + 1. 2. 3. 4. 5. 6. 7. 8. 9.] (->vec res))) + (is (= [2 9] (shape-vec res))) + (is (= [1. 2. 3. 4. 5. 6. 7. 8. 9. + 1. 2. 3. 4. 5. 6. 7. 8. 9.] (->vec out))) + (is (= [2 9] (shape-vec out))))) + +(deftest test-instance-norm + (let [x (ndarray/array [1.1 2.2 3.3 4.4] [2 1 2]) + gamma (ndarray/array [1.5] [1]) + beta (ndarray/array [0.5] [1]) + res (ndarray-api/instance-norm x gamma beta)] + (is (approx= 1e-4 [-0.9975 1.9975 + -0.9975 1.9975] (->vec res))) + (is (= [2 1 2] (shape-vec res))))) + +(deftest test-l2-normalization + (let [x (ndarray/array [1 2 3 4 2 2 5 6] [2 2 2]) + res1 (ndarray-api/l2-normalization {:data x}) ;; instance-wise + res2 (ndarray-api/l2-normalization {:data x :mode "instance"}) + res3 (ndarray-api/l2-normalization {:data x :mode "channel"}) + res4 (ndarray-api/l2-normalization {:data x :mode "spatial"})] + (is (approx= 1e-4 [0.1825 0.3651 + 0.5477 0.7303 + 0.2407 0.2407 + 0.6019 0.7223] (->vec res1))) + (is (approx= 1e-4 [0.1825 0.3651 + 0.5477 0.7303 + 0.2407 0.2407 + 0.6019 0.7223] (->vec res2))) + (is (approx= 1e-4 [0.3162 0.4472 + 0.9486 0.8944 + 0.3714 0.3162 + 0.9284 0.9486] (->vec res3))) + (is (approx= 1e-4 [0.4472 0.8944 + 0.6 0.8 + 0.7071 0.7071 + 0.6402 0.7682] (->vec res4))))) + +(deftest test-pad + (let [x (ndarray/array [1 2 3 + 4 5 6 + 7 8 9 + 10 11 12 + 11 12 13 + 14 15 16 + 17 18 19 + 20 21 22] + [2 2 2 3]) + res1 (ndarray-api/pad x "edge" [0,0,0,0,1,1,1,1]) + res2 (ndarray-api/pad {:data x :mode "constant" :pad-width [0,0,0,0,1,1,1,1] :constant-value 0})] + (is (= [1. 1. 2. 3. 3. + 1. 1. 2. 3. 3. + 4. 4. 5. 6. 6. + 4. 4. 5. 6. 6. + 7. 7. 8. 9. 9. + 7. 7. 8. 9. 9. + 10. 10. 11. 12. 12. + 10. 10. 11. 12. 12. + 11. 11. 12. 13. 13. + 11. 11. 12. 13. 13. + 14. 14. 15. 16. 16. + 14. 14. 15. 16. 16. + 17. 17. 18. 19. 19. + 17. 17. 18. 19. 19. + 20. 20. 21. 22. 22. + 20. 20. 21. 22. 22.] (->vec res1))) + (is (= [2 2 4 5] (shape-vec res1))) + (is (= [0. 0. 0. 0. 0. + 0. 1. 2. 3. 0. + 0. 4. 5. 6. 0. + 0. 0. 0. 0. 0. + + 0. 0. 0. 0. 0. + 0. 7. 8. 9. 0. + 0. 10. 11. 12. 0. + 0. 0. 0. 0. 0. + + 0. 0. 0. 0. 0. + 0. 11. 12. 13. 0. + 0. 14. 15. 16. 0. + 0. 0. 0. 0. 0. + + 0. 0. 0. 0. 0. + 0. 17. 18. 19. 0. + 0. 20. 21. 22. 0. + 0. 0. 0. 0. 0.] (->vec res2))) + (is (= [2 2 4 5] (shape-vec res2))))) + +(deftest test-roi-pooling + (let [xi [[[[ 0., 1., 2., 3., 4., 5.], + [ 6., 7., 8., 9., 10., 11.], + [ 12., 13., 14., 15., 16., 17.], + [ 18., 19., 20., 21., 22., 23.], + [ 24., 25., 26., 27., 28., 29.], + [ 30., 31., 32., 33., 34., 35.], + [ 36., 37., 38., 39., 40., 41.], + [ 42., 43., 44., 45., 46., 47.]]]] + x (ndarray/array (-> xi flatten vec) [1 1 8 6]) + y (ndarray/array [0 0 0 4 4] [1 5]) + res1 (ndarray-api/roi-pooling x y [2 2] 1.0) + res2 (ndarray-api/roi-pooling x y [2 2] 0.7)] + (is (= [14. 16. 26. 28.] (->vec res1))) + (is (= [1 1 2 2] (shape-vec res1))) + (is (= [7. 9. 19. 21.] (->vec res2))) + (is (= [1 1 2 2] (shape-vec res2))))) + +(deftest test-reshape + (let [x (ndarray/array (vec (range 4)) [4]) + y (ndarray/array (vec (range 24)) [2 3 4]) + z (ndarray/array (vec (range 120)) [2 3 4 5]) + res1 (ndarray-api/reshape {:data x :shape [2 2]})] + (is (= [0. 1. 2. 3.] (->vec res1))) + (is (= [2 2] (shape-vec res1))) + (is (= (map float (range 24)) (->vec (ndarray-api/reshape {:data y :shape [4 0 2]})))) + (is (= [4 3 2] (shape-vec (ndarray-api/reshape {:data y :shape [4 0 2]})))) + (is (= [2 3 4] (shape-vec (ndarray-api/reshape {:data y :shape [2 0 0]})))) + (is (= [6 1 4] (shape-vec (ndarray-api/reshape {:data y :shape [6 1 -1]})))) + (is (= [3 1 8] (shape-vec (ndarray-api/reshape {:data y :shape [3 -1 8]})))) + (is (= [24] (shape-vec (ndarray-api/reshape {:data y :shape [-1]})))) + (is (= [2 3 4] (shape-vec (ndarray-api/reshape {:data y :shape [-2]})))) + (is (= [2 3 4] (shape-vec (ndarray-api/reshape {:data y :shape [2 -2]})))) + (is (= [2 3 4 1 1] (shape-vec (ndarray-api/reshape {:data y :shape [-2 1 1]})))) + (is (= [6 4] (shape-vec (ndarray-api/reshape {:data y :shape [-3 4]})))) + (is (= [6 20] (shape-vec (ndarray-api/reshape {:data z :shape [-3 -3]})))) + (is (= [2 12] (shape-vec (ndarray-api/reshape {:data y :shape [0 -3]})))) + (is (= [6 4] (shape-vec (ndarray-api/reshape {:data y :shape [-3 -2]})))) + (is (= [1 2 3 4] (shape-vec (ndarray-api/reshape {:data y :shape [-4 1 2 -2]})))) + (is (= [2 1 3 4] (shape-vec (ndarray-api/reshape {:data y :shape [2 -4 -1 3 -2]})))))) + +(deftest test-sequence-last + (let [xi [[[ 1., 2., 3.], + [ 4., 5., 6.], + [ 7., 8., 9.]], + + [[ 10., 11., 12.], + [ 13., 14., 15.], + [ 16., 17., 18.]], + + [[ 19., 20., 21.], + [ 22., 23., 24.], + [ 25., 26., 27.]]] + x (ndarray/array (-> xi flatten vec) [3 3 3]) + seq-len1 (ndarray/array [1 1 1] [3]) + seq-len2 (ndarray/array [1 2 3] [3]) + ;; This test is failing with an exception + ;; (most likely a scala generation issue) + ;; res1 (ndarray-api/sequence-last x nil) + ] + ;; (is (= [] (->vec res1))) +)) + +(deftest test-sequence-mask + (let [xi [[[ 1., 2., 3.], + [ 4., 5., 6.]], + + [[ 7., 8., 9.], + [ 10., 11., 12.]], + + [[ 13., 14., 15.], + [ 16., 17., 18.]]] + x (ndarray/array (-> xi flatten vec) [3 2 3]) + seq-len1 (ndarray/array [1 1] [2]) + seq-len2 (ndarray/array [2 3] [2]) + ;; Same issue as previous test + ;; res1 (ndarray-api/sequence-mask x seq-len1) + ] + ;; (is (= [] (->vec res1))) +)) + +(deftest test-slice-channel + (let [xi [[[ 1.] [ 2.]] + [[ 3.] [ 4.]] + [[ 5.] [ 6.]]] + x (ndarray/array (-> xi flatten vec) [3 2 1]) + res1 (ndarray-api/slice-channel {:data x :num-outputs 2 :axis 1}) + res2 (ndarray-api/slice-channel {:data x :num-outputs 3 :axis 0}) + res3 (ndarray-api/slice-channel {:data x :num-outputs 3 :axis 0 :squeeze-axis 1})] + (is (= [1. 3. 5.] (->vec res1))) + (is (= [3 1 1] (shape-vec res1))) + (is (= [1. 2.] (->vec res2))) + (is (= [1 2 1] (shape-vec res2))) + (is (= [1. 2.] (->vec res3))) + (is (= [2 1] (shape-vec res3))))) + +(deftest test-softmax-activation + (let [x (ndarray/array [1 1 1 1 1 1] [2 3]) + res1 (ndarray-api/softmax-activation {:data x :mode "instance"})] + (is (approx= 1e-3 [0.333 0.333 0.333 + 0.333 0.333 0.333] (->vec res1))) + (is (= [2 3] (shape-vec res1))))) + +(deftest test-softmax-output + (let [datai [[1,2,3,4],[2,2,2,2],[3,3,3,3],[4,4,4,4]] + data (ndarray/array (-> datai flatten vec) [4 4]) + label (ndarray/array [1,0,2,3] [4]) + res1 (ndarray-api/softmax-output data label)] + (is (approx= 1e-4 [0.0321 0.0871 0.2369 0.6439 + 0.25 0.25 0.25 0.25 + 0.25 0.25 0.25 0.25 + 0.25 0.25 0.25 0.25] (->vec res1))) + (is (= [4 4] (shape-vec res1))))) + +(deftest test-swap-axis + (let [x (ndarray/array (range 3) [1 3]) + y (ndarray/array (range 8) [2 2 2]) + res1 (ndarray-api/swap-axis {:data x :dim1 0 :dim2 1}) + res2 (ndarray-api/swap-axis {:data y :dim1 0 :dim2 2})] + (is (= [0. 1. 2.] (->vec res1))) + (is (= [3 1] (shape-vec res1))) + (is (= [0. 4. 2. 6. 1. 5. 3. 7.] (->vec res2))) + (is (= [2 2 2] (shape-vec res2))))) + +(deftest test-abs + (let [x (ndarray/array [-2 0 3] [3]) + res1 (ndarray-api/abs {:data x})] + (is (= [2. 0. 3.] (->vec res1))) + (is (= [3] (shape-vec res1))))) + +(deftest test-arccos + (let [x (ndarray/array [-1 -0.707 0 0.707 1] [5]) + pi Math/PI + res1 (ndarray-api/arccos {:data x})] + (is (approx= 1e-3 [pi (* 0.75 pi) (* 0.5 pi) (* 0.25 pi) 0.] (->vec res1))))) + +(deftest test-arcsin + (let [x (ndarray/array [-1 -0.707 0 0.707 1] [5]) + pi Math/PI + res1 (ndarray-api/arcsin {:data x})] + (is (approx= 1e-3 [(- (* 0.5 pi)) (- (* 0.25 pi)) 0 (* 0.25 pi) (* 0.5 pi)] (->vec res1))))) + +(deftest test-argmax + (let [x (ndarray/array (range 6) [2 3]) + res1 (ndarray-api/argmax {:data x :axis 0}) + res2 (ndarray-api/argmax {:data x :axis 1}) + res3 (ndarray-api/argmax {:data x :axis 0 :keepdims true}) + res4 (ndarray-api/argmax {:data x :axis 1 :keepdims true})] + (is (= [1. 1. 1.] (->vec res1))) + (is (= [3] (shape-vec res1))) + (is (= [2. 2.] (->vec res2))) + (is (= [2] (shape-vec res2))) + (is (= [1. 1. 1.] (->vec res3))) + (is (= [1 3] (shape-vec res3))) + (is (= [2. 2.] (->vec res4))) + (is (= [2 1] (shape-vec res4))))) + +(deftest test-argmax-channel + (let [x (ndarray/array (range 6) [2 3]) + res1 (ndarray-api/argmax-channel {:data x})] + (is (= [2. 2.] (->vec res1))) + (is (= [2] (shape-vec res1))))) + +(deftest test-argmin + (let [x (ndarray/array (reverse (range 6)) [2 3]) + res1 (ndarray-api/argmin {:data x :axis 0}) + res2 (ndarray-api/argmin {:data x :axis 1}) + res3 (ndarray-api/argmin {:data x :axis 0 :keepdims true}) + res4 (ndarray-api/argmin {:data x :axis 1 :keepdims true})] + (is (= [1. 1. 1.] (->vec res1))) + (is (= [3] (shape-vec res1))) + (is (= [2. 2.] (->vec res2))) + (is (= [2] (shape-vec res2))) + (is (= [1. 1. 1.] (->vec res3))) + (is (= [1 3] (shape-vec res3))) + (is (= [2. 2.] (->vec res4))) + (is (= [2 1] (shape-vec res4))))) + +(deftest test-argsort + (let [x (ndarray/array [0.3 0.2 0.4 + 0.1 0.3 0.2] + [2 3]) + y (ndarray/array [0.3 0.2 0.4 0.1 0.3 0.2] [6]) + res1 (ndarray-api/argsort {:data x}) + res2 (ndarray-api/argsort {:data x :axis 0}) + res3 (ndarray-api/argsort {:data y})] + (is (= [1. 0. 2. + 0. 2. 1.] + (->vec res1))) + (is (= [2 3] (shape-vec res1))) + (is (= [1. 0. 1. + 0. 1. 0.] + (->vec res2))) + (is (= [2 3] (shape-vec res1))) + (is (= [3. 1. 5. 0. 4. 2.] (->vec res3))) + (is (= [6] (shape-vec res3))))) + +(deftest test-batch-take + (let [x (ndarray/array (range 6) [3 2]) + i (ndarray/as-type (ndarray/array [0 1 0] [3]) dtype/INT32) + res1 (ndarray-api/batch-take x i) ] + (is (= [0. 3. 4.] (->vec res1))))) + +(deftest test-broadcast-add + (let [x (ndarray/ones [2 3]) + y (ndarray/array (range 2) [2 1]) + res1 (ndarray-api/broadcast-add x y)] + (is (= [1. 1. 1. 2. 2. 2.] (->vec res1))) + (is (= [2 3] (shape-vec res1))))) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj new file mode 100644 index 000000000000..b642ad75d1d0 --- /dev/null +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/symbol_api_test.clj @@ -0,0 +1,61 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns org.apache.clojure-mxnet.symbol-api-test + (:require [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.executor :as executor] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.symbol :as sym] + [org.apache.clojure-mxnet.symbol-api :as sym-api] + [org.apache.clojure-mxnet.util :as util] + [clojure.test :refer :all] + [org.apache.clojure-mxnet.context :as context])) + +(deftest test-compose + (let [data (sym/variable "data") + net1 (sym-api/fully-connected {:data data :num-hidden 10 :name "fc1"}) + net1 (sym-api/fully-connected {:data net1 :num-hidden 100 :name "fc2"} ) + + net2 (sym-api/fully-connected {:num-hidden 10 :name "fc3"}) + net2 (sym-api/activation {:data net2 :act-type "relu"}) + net2 (sym-api/fully-connected {:data net2 :num-hidden 20 :name "fc4"}) + + composed (sym/apply net2 "composed" {"fc3_data" net1}) + + multi-out (sym/group [composed net1])] + + (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1))) + (is (= 2 (count (sym/list-outputs multi-out)))))) + +(deftest test-symbol-internal + (let [data (sym/variable "data") + oldfc (sym-api/fully-connected {:data data :num-hidden 10 :name"fc1"}) + net1 (sym-api/fully-connected {:data oldfc :num-hidden 100 :name"fc2"})] + (is (= ["data" "fc1_weight" "fc1_bias" "fc2_weight" "fc2_bias"] (sym/list-arguments net1))) + (= (sym/list-arguments oldfc) (-> (sym/get-internals net1) + (sym/get "fc1_output") + (sym/list-arguments))))) + +(deftest test-infer-type + (let [data (sym/variable "data") + f32data (sym-api/cast {:data data :dtype "float32"}) + fc1 (sym-api/fully-connected {:data f32data :num-hidden 128 :name"fc1"}) + mlp (sym-api/softmax-output {:data fc1 :name"softmax"}) + [arg out aux] (sym/infer-type mlp {:data dtype/FLOAT64})] + (is (= [dtype/FLOAT64 dtype/FLOAT32 dtype/FLOAT32 dtype/FLOAT32] (util/buffer->vec arg))) + (is (= [dtype/FLOAT32] (util/buffer->vec out))) + (is (= [] (util/buffer->vec aux)))))