Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

#13441 [Clojure] Add Spec Validations for the Random namespace #13523

Merged
merged 1 commit into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@
;;;train

;;initialize with random noise
img (ndarray/- (random/uniform 0 255 content-np-shape dev) 128)
img (ndarray/- (random/uniform 0 255 content-np-shape {:ctx dev}) 128)
;;; img (random/uniform -0.1 0.1 content-np-shape dev)
;; img content-np
lr-sched (lr-scheduler/factor-scheduler 10 0.9)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@
(org.apache.mxnet.optimizer SGD DCASGD NAG AdaDelta RMSProp AdaGrad Adam SGLD)
(org.apache.mxnet FactorScheduler)))

(s/def ::learning-rate float?)
(s/def ::momentum float?)
(s/def ::wd float?)
(s/def ::clip-gradient float?)
(s/def ::lr-scheduler #(instance? FactorScheduler))
(s/def ::learning-rate number?)
(s/def ::momentum number?)
(s/def ::wd number?)
(s/def ::clip-gradient number?)
(s/def ::lr-scheduler #(instance? FactorScheduler %))
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler]))

(defn sgd
Expand All @@ -43,7 +43,7 @@
([]
(sgd {})))

(s/def ::lambda float?)
(s/def ::lambda number?)
(s/def ::dcasgd-opts (s/keys :opt-un [::learning-rate ::momentum ::lambda ::wd ::clip-gradient ::lr-scheduler]))

(defn dcasgd
Expand Down Expand Up @@ -77,9 +77,9 @@
([]
(nag {})))

(s/def ::rho float?)
(s/def ::rescale-gradient float?)
(s/def ::epsilon float?)
(s/def ::rho number?)
(s/def ::rescale-gradient number?)
(s/def ::epsilon number?)
(s/def ::ada-delta-opts (s/keys :opt-un [::rho ::rescale-gradient ::epsilon ::wd ::clip-gradient]))

(defn ada-delta
Expand All @@ -96,8 +96,8 @@
([]
(ada-delta {})))

(s/def gamma1 float?)
(s/def gamma2 float?)
(s/def gamma1 number?)
(s/def gamma2 number?)
(s/def ::rms-prop-opts (s/keys :opt-un [::learning-rate ::rescale-gradient ::gamma1 ::gamma2 ::wd ::clip-gradient]))

(defn rms-prop
Expand Down Expand Up @@ -144,8 +144,8 @@
([]
(ada-grad {})))

(s/def ::beta1 float?)
(s/def ::beta2 float?)
(s/def ::beta1 number?)
(s/def ::beta2 number?)
(s/def ::adam-opts (s/keys :opt-un [::learning-rate ::beta1 ::beta2 ::epsilon ::decay-factor ::wd ::clip-gradient ::lr-scheduler]))

(defn adam
Expand Down
30 changes: 27 additions & 3 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,18 @@
;;

(ns org.apache.clojure-mxnet.random
(:require [org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet Random)))
(:require
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.context :as context]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Context Random)))

(s/def ::low number?)
(s/def ::high number?)
(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
(s/def ::ctx #(instance? Context %))
(s/def ::uniform-opts (s/keys :opt-un [::ctx]))

(defn uniform
"Generate uniform distribution in [low, high) with shape.
Expand All @@ -29,10 +39,18 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([low high shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
(util/validate! ::low low "Incorrect random uniform parameter")
(util/validate! ::high high "Incorrect random uniform parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out))
([low high shape-vec]
(uniform low high shape-vec {})))

(s/def ::loc number?)
(s/def ::scale number?)
(s/def ::normal-opts (s/keys :opt-un [::ctx]))

(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
Expand All @@ -43,10 +61,15 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([loc scale shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::normal-opts opts "Incorrect random normal parameters")
(util/validate! ::loc loc "Incorrect random normal parameters")
(util/validate! ::scale scale "Incorrect random normal parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))

(s/def ::seed-state number?)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
Expand All @@ -58,4 +81,5 @@
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
[seed-state]
(Random/seed (int seed-state)))
(util/validate! ::seed-state seed-state "Incorrect seed parameters")
(Random/seed (int seed-state)))
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
test (sym/transpose data)
shape-vec [3 4]
ctx (context/default-context)
arr-data (random/uniform 0 100 shape-vec ctx)
arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
trans (ndarray/transpose (ndarray/copy arr-data))
exec-test (sym/bind test ctx {"data" arr-data})
out (-> (executor/forward exec-test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
(let [[a b] [-10 10]
shape [100 100]
_ (random/seed 128)
un1 (random/uniform a b shape {:context ctx})
un1 (random/uniform a b shape {:ctx ctx})
_ (random/seed 128)
un2 (random/uniform a b shape {:context ctx})]
un2 (random/uniform a b shape {:ctx ctx})]
(is (= un1 un2))
(is (< (Math/abs
(/ (/ (apply + (ndarray/->vec un1))
Expand All @@ -52,3 +52,16 @@
(is (< (Math/abs (- mean mu)) 0.1))
(is (< (Math/abs (- stddev sigma)) 0.1)))))

(defn random-or-normal [fn_]
(is (thrown? Exception (fn_ 'a 2 [])))
(is (thrown? Exception (fn_ 1 'b [])))
(is (thrown? Exception (fn_ 1 2 [-1])))
(is (thrown? Exception (fn_ 1 2 [2 3 0])))
(is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
(let [ctx (context/default-context)]
(is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))

(deftest test-random-parameters-specs
(random-or-normal random/normal)
(random-or-normal random/uniform)
(is (thrown? Exception (random/seed "a"))))