Skip to content

Commit

Permalink
[clojure][generator] ndarray/symbol api random merged (apache#14800)
Browse files Browse the repository at this point in the history
* [clojure][generator] add random namespace generation

* `ndarray_random_api`
* `symbol_random_api`

* fix tests
  • Loading branch information
Chouffe authored and haohuw committed Jun 23, 2019
1 parent 6792c5f commit 913b670
Show file tree
Hide file tree
Showing 9 changed files with 731 additions and 108 deletions.
298 changes: 200 additions & 98 deletions contrib/clojure-package/src/dev/generator.clj

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

(ns org.apache.clojure-mxnet.ndarray-api
"Experimental NDArray API"
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])

(:refer-clojure
:exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.ndarray-random-api
"Experimental NDArray Random API"
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[t6.from-scala.core :refer [$] :as $])
(:import (org.apache.mxnet NDArrayAPI)))

;; loads the generated functions into the namespace
(do (clojure.core/load "gen/ndarray_random_api"))
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;

(ns org.apache.clojure-mxnet.symbol-random-api
"Experimental Symbol Random API"
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle ref])
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.executor :as ex]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[t6.from-scala.core :refer [$] :as $]
[org.apache.clojure-mxnet.ndarray :as ndarray])
(:import (org.apache.mxnet SymbolAPI)))

;; loads the generated functions into the namespace
(do (clojure.core/load "gen/symbol_random_api"))
47 changes: 41 additions & 6 deletions contrib/clojure-package/test/dev/generator_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@
(is (= "foo-bar" (gen/clojure-case "Foo_Bar")))
(is (= "div+" (gen/clojure-case "/+"))))

(deftest fn-name->random-fn-name
(is (= "poisson" (gen/fn-name->random-fn-name "-random-poisson")))
(is (= "poisson-like" (gen/fn-name->random-fn-name "-sample-poisson"))))

(deftest remove-prefix
(is (= "randint" (gen/remove-prefix "-random-" "-random-randint")))
(is (= "exponential" (gen/remove-prefix "-sample-" "-sample-exponential"))))

(deftest in-namespace-random?
(is (gen/in-namespace-random? "random_randint"))
(is (gen/in-namespace-random? "sample_poisson"))
(is (not (gen/in-namespace-random? "rnn")))
(is (not (gen/in-namespace-random? "activation"))))

(defn ndarray-reflect-info [name]
(->> gen/ndarray-public-no-default
(filter #(= name (str (:name %))))
Expand Down Expand Up @@ -317,14 +331,25 @@
(deftest test-write-to-file
(testing "symbol-api"
(let [fname "test/test-symbol-api.clj"
_ (gen/write-to-file [(first gen/all-symbol-api-functions)
(second gen/all-symbol-api-functions)]
gen/symbol-api-gen-ns
fns (gen/all-symbol-api-functions gen/op-names)
_ (gen/write-to-file [(first fns) (second fns)]
(gen/symbol-api-gen-ns false)
fname)
good-contents (slurp "test/good-test-symbol-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))

(testing "symbol-random-api"
(let [fname "test/test-symbol-random-api.clj"
fns (gen/all-symbol-random-api-functions gen/op-names)
_ (gen/write-to-file [(first fns) (second fns)]
(gen/symbol-api-gen-ns true)
fname)
good-contents (slurp "test/good-test-symbol-random-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))


(testing "symbol"
(let [fname "test/test-symbol.clj"
_ (gen/write-to-file [(first gen/all-symbol-functions)]
Expand All @@ -336,14 +361,24 @@

(testing "ndarray-api"
(let [fname "test/test-ndarray-api.clj"
_ (gen/write-to-file [(first gen/all-ndarray-api-functions)
(second gen/all-ndarray-api-functions)]
gen/ndarray-api-gen-ns
fns (gen/all-ndarray-api-functions gen/op-names)
_ (gen/write-to-file [(first fns) (second fns)]
(gen/ndarray-api-gen-ns false)
fname)
good-contents (slurp "test/good-test-ndarray-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))

(testing "ndarray-random-api"
(let [fname "test/test-ndarray-random-api.clj"
fns (gen/all-ndarray-random-api-functions gen/op-names)
_ (gen/write-to-file [(first fns) (second fns)]
(gen/ndarray-api-gen-ns true)
fname)
good-contents (slurp "test/good-test-ndarray-random-api.clj")
contents (slurp fname)]
(is (= good-contents contents))))

(testing "ndarray"
(let [fname "test/test-ndarray.clj"
_ (gen/write-to-file [(first gen/all-ndarray-functions)]
Expand Down
95 changes: 95 additions & 0 deletions contrib/clojure-package/test/good-test-ndarray-random-api.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
(ns
^{:doc "Experimental"}
org.apache.clojure-mxnet.ndarray-random-api
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet NDArrayAPI)))

;; Do not edit - this is auto-generated

;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;




(defn
exponential
"Draw random samples from an exponential distribution.
Samples are distributed according to an exponential distribution parametrized by *lambda* (rate).
Example::
exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364],
[ 0.04146638, 0.31715935]]
Defined in src/operator/random/sample_op.cc:L137
`lam`: Lambda parameter (rate) of the exponential distribution. (optional)
`shape`: Shape of the output. (optional)
`ctx`: Context of output, in format [cpu|gpu|cpu_pinned](n). Only used for imperative calls. (optional)
`dtype`: DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None). (optional)
`out`: Output array. (optional)"
([] (exponential {}))
([{:keys [lam shape ctx dtype out],
:or {lam nil, shape nil, ctx nil, dtype nil, out nil},
:as opts}]
(util/coerce-return
(NDArrayAPI/random_exponential
(util/->option lam)
(util/->option (clojure.core/when shape (mx-shape/->shape shape)))
(util/->option ctx)
(util/->option dtype)
(util/->option out)))))

(defn
gamma
"Draw random samples from a gamma distribution.
Samples are distributed according to a gamma distribution parametrized by *alpha* (shape) and *beta* (scale).
Example::
gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289],
[ 3.91697288, 3.65933681]]
Defined in src/operator/random/sample_op.cc:L125
`alpha`: Alpha parameter (shape) of the gamma distribution. (optional)
`beta`: Beta parameter (scale) of the gamma distribution. (optional)
`shape`: Shape of the output. (optional)
`ctx`: Context of output, in format [cpu|gpu|cpu_pinned](n). Only used for imperative calls. (optional)
`dtype`: DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None). (optional)
`out`: Output array. (optional)"
([] (gamma {}))
([{:keys [alpha beta shape ctx dtype out],
:or {alpha nil, beta nil, shape nil, ctx nil, dtype nil, out nil},
:as opts}]
(util/coerce-return
(NDArrayAPI/random_gamma
(util/->option alpha)
(util/->option beta)
(util/->option (clojure.core/when shape (mx-shape/->shape shape)))
(util/->option ctx)
(util/->option dtype)
(util/->option out)))))

118 changes: 118 additions & 0 deletions contrib/clojure-package/test/good-test-symbol-random-api.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
(ns
^{:doc "Experimental"}
org.apache.clojure-mxnet.symbol-random-api
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle ref])
(:require [org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet SymbolAPI)))

;; Do not edit - this is auto-generated

;; Licensed to the Apache Software Foundation (ASF) under one or more
;; contributor license agreements. See the NOTICE file distributed with
;; this work for additional information regarding copyright ownership.
;; The ASF licenses this file to You under the Apache License, Version 2.0
;; (the "License"); you may not use this file except in compliance with
;; the License. You may obtain a copy of the License at
;;
;; http://www.apache.org/licenses/LICENSE-2.0
;;
;; Unless required by applicable law or agreed to in writing, software
;; distributed under the License is distributed on an "AS IS" BASIS,
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
;; See the License for the specific language governing permissions and
;; limitations under the License.
;;




(defn
exponential
"Draw random samples from an exponential distribution.
Samples are distributed according to an exponential distribution parametrized by *lambda* (rate).
Example::
exponential(lam=4, shape=(2,2)) = [[ 0.0097189 , 0.08999364],
[ 0.04146638, 0.31715935]]
Defined in src/operator/random/sample_op.cc:L137
`lam`: Lambda parameter (rate) of the exponential distribution. (optional)
`shape`: Shape of the output. (optional)
`ctx`: Context of output, in format [cpu|gpu|cpu_pinned](n). Only used for imperative calls. (optional)
`dtype`: DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None). (optional)
`name`: Name of the symbol (optional)
`attr`: Attributes of the symbol (optional)"
[{:keys [lam shape ctx dtype name attr],
:or {lam nil, shape nil, ctx nil, dtype nil, name nil, attr nil},
:as opts}]
(util/coerce-return
(SymbolAPI/random_exponential
(util/->option lam)
(util/->option (clojure.core/when shape (mx-shape/->shape shape)))
(util/->option ctx)
(util/->option dtype)
name
(clojure.core/when
attr
(clojure.core/->>
attr
(clojure.core/mapv
(clojure.core/fn [[k v]] [k (clojure.core/str v)]))
(clojure.core/into {})
util/convert-map)))))

(defn
gamma
"Draw random samples from a gamma distribution.
Samples are distributed according to a gamma distribution parametrized by *alpha* (shape) and *beta* (scale).
Example::
gamma(alpha=9, beta=0.5, shape=(2,2)) = [[ 7.10486984, 3.37695289],
[ 3.91697288, 3.65933681]]
Defined in src/operator/random/sample_op.cc:L125
`alpha`: Alpha parameter (shape) of the gamma distribution. (optional)
`beta`: Beta parameter (scale) of the gamma distribution. (optional)
`shape`: Shape of the output. (optional)
`ctx`: Context of output, in format [cpu|gpu|cpu_pinned](n). Only used for imperative calls. (optional)
`dtype`: DType of the output in case this can't be inferred. Defaults to float32 if not defined (dtype=None). (optional)
`name`: Name of the symbol (optional)
`attr`: Attributes of the symbol (optional)"
[{:keys [alpha beta shape ctx dtype name attr],
:or
{alpha nil,
beta nil,
shape nil,
ctx nil,
dtype nil,
name nil,
attr nil},
:as opts}]
(util/coerce-return
(SymbolAPI/random_gamma
(util/->option alpha)
(util/->option beta)
(util/->option (clojure.core/when shape (mx-shape/->shape shape)))
(util/->option ctx)
(util/->option dtype)
name
(clojure.core/when
attr
(clojure.core/->>
attr
(clojure.core/mapv
(clojure.core/fn [[k v]] [k (clojure.core/str v)]))
(clojure.core/into {})
util/convert-map)))))

Loading

0 comments on commit 913b670

Please sign in to comment.