diff --git a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj index 50f95c9750ee..fcf402f3466d 100644 --- a/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj +++ b/contrib/clojure-package/examples/neural-style/src/neural_style/core.clj @@ -193,7 +193,7 @@ ;;;train ;;initialize with random noise - img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128) + img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128) ;;; img (random/uniform -0.1 0.1 content-np-shape dev) ;; img content-np lr-sched (lr-scheduler/factor-scheduler 10 0.9) diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj index f77f5532bfb1..672090a899b3 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/optimizer.clj @@ -24,11 +24,11 @@ (org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD) (org.apache.mxnet FactorScheduler))) -(s/def ::learning-rate float?) -(s/def ::momentum float?) -(s/def ::wd float?) -(s/def ::clip-gradient float?) -(s/def ::lr-scheduler #(instance? FactorScheduler)) +(s/def ::learning-rate number?) +(s/def ::momentum number?) +(s/def ::wd number?) +(s/def ::clip-gradient number?) +(s/def ::lr-scheduler #(instance? FactorScheduler %)) (s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler])) (defn sgd @@ -43,7 +43,7 @@ ([] (sgd {}))) -(s/def ::lambda float?) +(s/def ::lambda number?) (s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd ::clip-gradient ::lr-scheduler])) (defn dcasgd @@ -77,9 +77,9 @@ ([] (nag {}))) -(s/def ::rho float?) -(s/def ::rescale-gradient float?) -(s/def ::epsilon float?) +(s/def ::rho number?) +(s/def ::rescale-gradient number?) +(s/def ::epsilon number?) (s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon ::wd ::clip-gradient])) (defn ada-delta @@ -96,8 +96,8 @@ ([] (ada-delta {}))) -(s/def gamma1 float?) -(s/def gamma2 float?) +(s/def gamma1 number?) +(s/def gamma2 number?) (s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient])) (defn rms-prop @@ -144,8 +144,8 @@ ([] (ada-grad {}))) -(s/def ::beta1 float?) -(s/def ::beta2 float?) +(s/def ::beta1 number?) +(s/def ::beta2 number?) (s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon ::decay-factor ::wd ::clip-gradient ::lr-scheduler])) (defn adam diff --git a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj index d6e33789a629..0ec2039ba79b 100644 --- a/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj +++ b/contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj @@ -16,8 +16,18 @@ ;; (ns org.apache.clojure-mxnet.random - (:require [org.apache.clojure-mxnet.shape :as mx-shape]) - (:import (org.apache.mxnet Random))) + (:require + [org.apache.clojure-mxnet.shape :as mx-shape] + [org.apache.clojure-mxnet.context :as context] + [clojure.spec.alpha :as s] + [org.apache.clojure-mxnet.util :as util]) + (:import (org.apache.mxnet Context Random))) + +(s/def ::low number?) +(s/def ::high number?) +(s/def ::shape-vec (s/coll-of pos-int? :kind vector?)) +(s/def ::ctx #(instance? Context %)) +(s/def ::uniform-opts (s/keys :opt-un [::ctx])) (defn uniform "Generate uniform distribution in [low, high) with shape. @@ -29,10 +39,18 @@ out: Output place holder} returns: The result ndarray with generated result./" ([low high shape-vec {:keys [ctx out] :as opts}] + (util/validate! ::uniform-opts opts "Incorrect random uniform parameters") + (util/validate! ::low low "Incorrect random uniform parameter") + (util/validate! ::high high "Incorrect random uniform parameters") + (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters") (Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out)) ([low high shape-vec] (uniform low high shape-vec {}))) +(s/def ::loc number?) +(s/def ::scale number?) +(s/def ::normal-opts (s/keys :opt-un [::ctx])) + (defn normal "Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape. loc: The standard deviation of the normal distribution @@ -43,10 +61,15 @@ out: Output place holder} returns: The result ndarray with generated result./" ([loc scale shape-vec {:keys [ctx out] :as opts}] + (util/validate! ::normal-opts opts "Incorrect random normal parameters") + (util/validate! ::loc loc "Incorrect random normal parameters") + (util/validate! ::scale scale "Incorrect random normal parameters") + (util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters") (Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out)) ([loc scale shape-vec] (normal loc scale shape-vec {}))) +(s/def ::seed-state number?) (defn seed " Seed the random number generators in mxnet. This seed will affect behavior of functions in this module, @@ -58,4 +81,5 @@ This means if you set the same seed, the random number sequence generated from GPU0 can be different from CPU." [seed-state] - (Random/seed (int seed-state))) + (util/validate! ::seed-state seed-state "Incorrect seed parameters") + (Random/seed (int seed-state))) \ No newline at end of file diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj index 1b4b2ea2fbe3..c97711b5fed6 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/operator_test.clj @@ -462,7 +462,7 @@ test (sym/transpose data) shape-vec [3 4] ctx (context/default-context) - arr-data (random/uniform 0 100 shape-vec ctx) + arr-data (random/uniform 0 100 shape-vec {:ctx ctx}) trans (ndarray/transpose (ndarray/copy arr-data)) exec-test (sym/bind test ctx {"data" arr-data}) out (-> (executor/forward exec-test) diff --git a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj index c4e9198073a8..6952335c1390 100644 --- a/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj +++ b/contrib/clojure-package/test/org/apache/clojure_mxnet/random_test.clj @@ -26,9 +26,9 @@ (let [[a b] [-10 10] shape [100 100] _ (random/seed 128) - un1 (random/uniform a b shape {:context ctx}) + un1 (random/uniform a b shape {:ctx ctx}) _ (random/seed 128) - un2 (random/uniform a b shape {:context ctx})] + un2 (random/uniform a b shape {:ctx ctx})] (is (= un1 un2)) (is (< (Math/abs (/ (/ (apply + (ndarray/->vec un1)) @@ -52,3 +52,16 @@ (is (< (Math/abs (- mean mu)) 0.1)) (is (< (Math/abs (- stddev sigma)) 0.1))))) +(defn random-or-normal [fn_] + (is (thrown? Exception (fn_ 'a 2 []))) + (is (thrown? Exception (fn_ 1 'b []))) + (is (thrown? Exception (fn_ 1 2 [-1]))) + (is (thrown? Exception (fn_ 1 2 [2 3 0]))) + (is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"}))) + (let [ctx (context/default-context)] + (is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx})))))) + +(deftest test-random-parameters-specs + (random-or-normal random/normal) + (random-or-normal random/uniform) + (is (thrown? Exception (random/seed "a")))) \ No newline at end of file