Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[clojure-package][wip] add ->nd-vec function in ndarray.clj (#14308)
Browse files Browse the repository at this point in the history
* [clojure-package][wip] add `->nd-vec` function in `ndarray.clj`

* WIP
* Unit tests need to be added

* [clojure-package][ndarray] add unit tests for `->nd-vec` util fn
  • Loading branch information
Chouffe authored and gigasquid committed Mar 10, 2019
1 parent c645591 commit 8be97d7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
58 changes: 52 additions & 6 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/ndarray.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
;;

(ns org.apache.clojure-mxnet.ndarray
"NDArray API for Clojure package."
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[t6.from-scala.core :refer [$] :as $])
(:require
[clojure.spec.alpha :as s]

[org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[t6.from-scala.core :refer [$] :as $])
(:import (org.apache.mxnet NDArray)))

;; loads the generated functions into the namespace
Expand Down Expand Up @@ -167,3 +170,46 @@

(defn shape-vec [ndarray]
(mx-shape/->vec (shape ndarray)))

(s/def ::ndarray #(instance? NDArray %))
(s/def ::vector vector?)
(s/def ::sequential sequential?)
(s/def ::shape-vec-match-vec
(fn [[v vec-shape]] (= (count v) (reduce clojure.core/* 1 vec-shape))))

(s/fdef vec->nd-vec
:args (s/cat :v ::sequential :shape-vec ::sequential)
:ret ::vector)

(defn- vec->nd-vec
"Convert a vector `v` into a n-dimensional vector given the `shape-vec`
Ex:
(vec->nd-vec [1 2 3] [1 1 3]) ;[[[1 2 3]]]
(vec->nd-vec [1 2 3 4 5 6] [2 3 1]) ;[[[1] [2] [3]] [[4] [5] [6]]]
(vec->nd-vec [1 2 3 4 5 6] [1 2 3]) ;[[[1 2 3]] [4 5 6]]]
(vec->nd-vec [1 2 3 4 5 6] [3 1 2]) ;[[[1 2]] [[3 4]] [[5 6]]]
(vec->nd-vec [1 2 3 4 5 6] [3 2]) ;[[1 2] [3 4] [5 6]]"
[v [s1 & ss :as shape-vec]]
(util/validate! ::sequential v "Invalid input vector `v`")
(util/validate! ::sequential shape-vec "Invalid input vector `shape-vec`")
(util/validate! ::shape-vec-match-vec
[v shape-vec]
"Mismatch between vector `v` and vector `shape-vec`")
(if-not (seq ss)
(vec v)
(->> v
(partition (clojure.core// (count v) s1))
vec
(mapv #(vec->nd-vec % ss)))))

(s/fdef ->nd-vec :args (s/cat :ndarray ::ndarray) :ret ::vector)

(defn ->nd-vec
"Convert an ndarray `ndarray` into a n-dimensional Clojure vector.
Ex:
(->nd-vec (array [1] [1 1 1])) ;[[[1.0]]]
(->nd-vec (array [1 2 3] [3 1 1])) ;[[[1.0]] [[2.0]] [[3.0]]]
(->nd-vec (array [1 2 3 4 5 6]) [3 1 2]) ;[[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]"
[ndarray]
(util/validate! ::ndarray ndarray "Invalid input array")
(vec->nd-vec (->vec ndarray) (shape-vec ndarray)))
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,15 @@
(is (= [2 2] (ndarray/->int-vec nda)))
(is (= [2.0 2.0] (ndarray/->double-vec nda)))
(is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda)))))

(deftest test->nd-vec
(is (= [[[1.0]]]
(ndarray/->nd-vec (ndarray/array [1] [1 1 1]))))
(is (= [[[1.0]] [[2.0]] [[3.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3] [3 1 1]))))
(is (= [[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 1 2]))))
(is (= [[[1.0] [2.0]] [[3.0] [4.0]] [[5.0] [6.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 2 1]))))
(is (thrown-with-msg? Exception #"Invalid input array"
(ndarray/->nd-vec [1 2 3 4 5]))))

0 comments on commit 8be97d7

Please sign in to comment.