Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[clojure][generator][WIP] add random namespace generation #14750

Closed
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 110 additions & 3 deletions contrib/clojure-package/src/dev/generator.clj
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@
(clojure.pprint/pprint f w)
(.write w "\n"))))

(defn remove-prefix
[prefix s]
(let [regex (re-pattern (str prefix "(.*)"))
replacement "$1"]
(clojure.string/replace s regex replacement)))

;;;;;;; Common operations

(def libinfo (Base/_LIB))
Expand Down Expand Up @@ -470,6 +476,53 @@
(println "Generating symbol-api file")
(write-to-file all-symbol-api-functions symbol-api-gen-ns "src/org/apache/clojure_mxnet/gen/symbol_api.clj"))

;;;;;;; SymbolRandomAPI

(defn gen-symbol-random-api-function [op-name]
Chouffe marked this conversation as resolved.
Show resolved Hide resolved
(let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
fn-name (remove-prefix "-random-" fn-name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

params (mapv (fn [{:keys [name type optional?] :as opts}]
(assoc opts
:sym (symbol name)
:optional? (or optional?
(= "NDArray-or-Symbol" type))))
(conj args
{:name "name"
:type "String"
:optional? true
:description "Name of the symbol"}
{:name "attr"
:type "Map[String, String]"
:optional? true
:description "Attributes of the symbol"}))
doc (gen-symbol-api-doc fn-description params)
default-call (gen-symbol-api-default-arity op-name params)]
`(~'defn ~(symbol fn-name)
~doc
~@default-call)))

(def all-symbol-random-api-functions
(->> op-names
(filter #(clojure.string/includes? % "random_"))
(mapv gen-symbol-random-api-function)))

(def symbol-random-api-gen-ns "(ns
^{:doc \"Experimental\"}
org.apache.clojure-mxnet.symbol-random-api
(:refer-clojure :exclude [* - + > >= < <= / cast concat identity flatten load max
min repeat reverse set sort take to-array empty sin
get apply shuffle ref])
(:require [org.apache.clojure-mxnet.util :as util]
[org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet SymbolAPI)))")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could not get it to work. Do you have any pointers on how this should be done?


(defn generate-symbol-random-api-file []
(println "Generating symbol-random-api file")
(write-to-file all-symbol-random-api-functions
symbol-random-api-gen-ns
"src/org/apache/clojure_mxnet/gen/symbol_random_api.clj"))


;;;;;;; NDArrayAPI

(defn ndarray-api-coerce-param
Expand Down Expand Up @@ -535,7 +588,9 @@
~default-call))))

(def all-ndarray-api-functions
(mapv gen-ndarray-api-function op-names))
(->> op-names
(remove #(clojure.string/includes? % "random_"))
(mapv gen-ndarray-api-function)))

(def ndarray-api-gen-ns "(ns
^{:doc \"Experimental\"}
Expand All @@ -554,12 +609,60 @@
ndarray-api-gen-ns
"src/org/apache/clojure_mxnet/gen/ndarray_api.clj"))

;;;;;;; NDArrayRandomAPI

(defn gen-ndarray-random-api-function [op-name]
(let [{:keys [fn-name fn-description args]} (gen-op-info op-name)
fn-name (remove-prefix "-random-" fn-name)
params (mapv (fn [{:keys [name] :as opts}]
(assoc opts :sym (symbol name)))
(conj args {:name "out"
:type "NDArray-or-Symbol"
:optional? true
:description "Output array."}))
doc (gen-ndarray-api-doc fn-description params)
opt-params (filter :optional? params)
req-params (remove :optional? params)
req-call (gen-ndarray-api-required-arity fn-name req-params)
default-call (gen-ndarray-api-default-arity op-name params)]
(if (= 1 (count req-params))
`(~'defn ~(symbol fn-name)
~doc
~@default-call)
`(~'defn ~(symbol fn-name)
~doc
~req-call
~default-call))))

(def all-ndarray-random-api-functions
(->> op-names
(filter #(clojure.string/includes? % "random_"))
(mapv gen-ndarray-random-api-function)))

(def ndarray-random-api-gen-ns "(ns
^{:doc \"Experimental\"}
org.apache.clojure-mxnet.ndarray-random-api
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet NDArrayAPI)))")

(defn generate-ndarray-random-api-file []
(println "Generating ndarray-random-api file")
(write-to-file all-ndarray-random-api-functions
ndarray-random-api-gen-ns
"src/org/apache/clojure_mxnet/gen/ndarray_random_api.clj"))

;;; autogen the files
(do
(generate-ndarray-file)
(generate-ndarray-api-file)
(generate-ndarray-random-api-file)
(generate-symbol-file)
(generate-symbol-api-file))
(generate-symbol-api-file)
(generate-symbol-random-api-file))


(comment
Expand All @@ -570,8 +673,12 @@

(gen-symbol-api-function "Activation")

(gen-ndarray-random-api-function "random_randint")

(gen-symbol-random-api-function "random_poisson")

;; This generates a file with the bulk of the nd-array functions
(generate-ndarray-file)

;; This generates a file with the bulk of the symbol functions
(generate-symbol-file) )
(generate-symbol-file))
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
(ns org.apache.clojure-mxnet.ndarray-random-api
"Experimental NDArray Random API"
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[t6.from-scala.core :refer [$] :as $])
(:import (org.apache.mxnet NDArrayAPI)))

;; loads the generated functions into the namespace
(do (clojure.core/load "gen/ndarray_random_api"))