Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[clojure-package] add ->nd-vec function in ndarray.clj #14308

Merged
merged 2 commits into from
Mar 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@
;;

(ns org.apache.clojure-mxnet.ndarray
"NDArray API for Clojure package."
(:refer-clojure :exclude [* - + > >= < <= / cast concat flatten identity load max
min repeat reverse set sort take to-array empty shuffle
ref])
(:require [org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[clojure.reflect :as r]
[t6.from-scala.core :refer [$] :as $])
(:require
[clojure.spec.alpha :as s]

[org.apache.clojure-mxnet.base :as base]
[org.apache.clojure-mxnet.context :as mx-context]
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.util :as util]
[t6.from-scala.core :refer [$] :as $])
(:import (org.apache.mxnet NDArray)))

;; loads the generated functions into the namespace
Expand Down Expand Up @@ -167,3 +170,46 @@

(defn shape-vec [ndarray]
(mx-shape/->vec (shape ndarray)))

(s/def ::ndarray #(instance? NDArray %))
(s/def ::vector vector?)
(s/def ::sequential sequential?)
(s/def ::shape-vec-match-vec
(fn [[v vec-shape]] (= (count v) (reduce clojure.core/* 1 vec-shape))))

(s/fdef vec->nd-vec
:args (s/cat :v ::sequential :shape-vec ::sequential)
:ret ::vector)

(defn- vec->nd-vec
"Convert a vector `v` into a n-dimensional vector given the `shape-vec`
Ex:
(vec->nd-vec [1 2 3] [1 1 3]) ;[[[1 2 3]]]
(vec->nd-vec [1 2 3 4 5 6] [2 3 1]) ;[[[1] [2] [3]] [[4] [5] [6]]]
(vec->nd-vec [1 2 3 4 5 6] [1 2 3]) ;[[[1 2 3]] [4 5 6]]]
(vec->nd-vec [1 2 3 4 5 6] [3 1 2]) ;[[[1 2]] [[3 4]] [[5 6]]]
(vec->nd-vec [1 2 3 4 5 6] [3 2]) ;[[1 2] [3 4] [5 6]]"
[v [s1 & ss :as shape-vec]]
(util/validate! ::sequential v "Invalid input vector `v`")
(util/validate! ::sequential shape-vec "Invalid input vector `shape-vec`")
(util/validate! ::shape-vec-match-vec
[v shape-vec]
"Mismatch between vector `v` and vector `shape-vec`")
(if-not (seq ss)
(vec v)
(->> v
(partition (clojure.core// (count v) s1))
vec
(mapv #(vec->nd-vec % ss)))))

(s/fdef ->nd-vec :args (s/cat :ndarray ::ndarray) :ret ::vector)

(defn ->nd-vec
"Convert an ndarray `ndarray` into a n-dimensional Clojure vector.
Ex:
(->nd-vec (array [1] [1 1 1])) ;[[[1.0]]]
(->nd-vec (array [1 2 3] [3 1 1])) ;[[[1.0]] [[2.0]] [[3.0]]]
(->nd-vec (array [1 2 3 4 5 6]) [3 1 2]) ;[[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]"
[ndarray]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the function spec, do you think it would also be nice to be able to do a util/validate in here to check that it is a NDArray? when called without instrumentation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes lets do that!

(util/validate! ::ndarray ndarray "Invalid input array")
(vec->nd-vec (->vec ndarray) (shape-vec ndarray)))
Original file line number Diff line number Diff line change
Expand Up @@ -473,3 +473,15 @@
(is (= [2 2] (ndarray/->int-vec nda)))
(is (= [2.0 2.0] (ndarray/->double-vec nda)))
(is (= [(byte 2) (byte 2)] (ndarray/->byte-vec nda)))))

(deftest test->nd-vec
(is (= [[[1.0]]]
(ndarray/->nd-vec (ndarray/array [1] [1 1 1]))))
(is (= [[[1.0]] [[2.0]] [[3.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3] [3 1 1]))))
(is (= [[[1.0 2.0]] [[3.0 4.0]] [[5.0 6.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 1 2]))))
(is (= [[[1.0] [2.0]] [[3.0] [4.0]] [[5.0] [6.0]]]
(ndarray/->nd-vec (ndarray/array [1 2 3 4 5 6] [3 2 1]))))
(is (thrown-with-msg? Exception #"Invalid input array"
(ndarray/->nd-vec [1 2 3 4 5]))))