diff --git a/contrib/clojure-package/examples/cnn-text-classification/README.md b/contrib/clojure-package/examples/cnn-text-classification/README.md index 86a8abb06e7a..19bb9137334b 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/README.md +++ b/contrib/clojure-package/examples/cnn-text-classification/README.md @@ -3,19 +3,19 @@ An example of text classification using CNN To use you must download the MR polarity dataset and put it in the path specified in the mr-dataset-path -The dataset can be obtained here: [https://github.com/yoonkim/CNN_sentence](https://github.com/yoonkim/CNN_sentence). The two files `rt-polarity.neg` +The dataset can be obtained here: [CNN_sentence](https://github.com/yoonkim/CNN_sentence). The two files `rt-polarity.neg` and `rt-polarity.pos` must be put in a directory. For example, `data/mr-data/rt-polarity.neg`. You also must download the glove word embeddings. The suggested one to use is the smaller 50 dimension one -`glove.6B.50d.txt` which is contained in the download file here [https://nlp.stanford.edu/projects/glove/](https://nlp.stanford.edu/projects/glove/) +`glove.6B.50d.txt` which is contained in the download file here: [GloVe](https://nlp.stanford.edu/projects/glove/) ## Usage You can run through the repl with -`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000})` +`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})` or -`JVM_OPTS="Xmx1g" lein run` (cpu) +`JVM_OPTS="-Xmx1g" lein run` (cpu) You can control the devices you run on by doing: @@ -24,10 +24,36 @@ You can control the devices you run on by doing: `lein run :gpu 2` - This will run on 2 gpu devices -The max-examples only loads 1000 each of the dataset to keep the time and memory down. To run all the examples, -change the main to be (train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10) +The max-examples only loads 1000 each of the dataset to keep the time and memory down. To run all the examples, +change the main to be (train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :glove}) and then run - `lein uberjar` - `java -Xms1024m -Xmx2048m -jar target/cnn-text-classification-0.1.0-SNAPSHOT-standalone.jar` + +## Usage with word2vec + +You can also use word2vec embeddings in order to train the text classification model. +Before training, you will need to download [GoogleNews-vectors-negative300.bin](https://drive.google.com/file/d/0B7XkCwpI5KDYNlNUTTlSS21pQmM/edit?usp=sharing) first. +Once you've downloaded the embeddings (which are in a gzipped format), +you'll need to unzip them and place them in the `contrib/clojure-package/data` directory. + +Then you can run training on a subset of examples through the repl using: +``` +(train-convnet {:embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec}) +``` +Note that loading word2vec embeddings consumes memory and takes some time. + +You can also train them using `JVM_OPTS="-Xmx8g" lein run` once you've modified +the parameters to `train-convnet` (see above) in `src/cnn_text_classification/classifier.clj`. +In order to run training with word2vec on the complete data set, you will need to run: +``` +(train-convnet {:embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec}) +``` +You should be able to achieve an accuracy of `~0.78` using the parameters above. + +## Usage with learned embeddings + +Lastly, similar to the python CNN text classification example, you can learn the embeddings based on training data. +This can be achieved by setting `:pretrained-embedding nil` (or omitting that parameter altogether). diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj index 94fd4f518c60..3c0288c9c343 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj +++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/classifier.clj @@ -30,34 +30,48 @@ (def data-dir "data/") (def mr-dataset-path "data/mr-data") ;; the MR polarity dataset path -(def glove-file-path "data/glove/glove.6B.50d.txt") (def num-filter 100) (def num-label 2) (def dropout 0.5) - - (when-not (.exists (io/file (str data-dir))) (do (println "Retrieving data for cnn text classification...") (sh "./get_data.sh"))) -(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size embedding-size]}] +(defn shuffle-data [test-num {:keys [data label sentence-count sentence-size vocab-size embedding-size pretrained-embedding]}] (println "Shuffling the data and splitting into training and test sets") (println {:sentence-count sentence-count :sentence-size sentence-size - :embedding-size embedding-size}) + :vocab-size vocab-size + :embedding-size embedding-size + :pretrained-embedding pretrained-embedding}) (let [shuffled (shuffle (map #(vector %1 %2) data label)) train-num (- (count shuffled) test-num) training (into [] (take train-num shuffled)) - test (into [] (drop train-num shuffled))] + test (into [] (drop train-num shuffled)) + ;; has to be channel x y + train-data-shape (if pretrained-embedding + [train-num 1 sentence-size embedding-size] + [train-num 1 sentence-size]) + ;; has to be channel x y + test-data-shape (if pretrained-embedding + [test-num 1 sentence-size embedding-size] + [test-num 1 sentence-size])] {:training {:data (ndarray/array (into [] (flatten (mapv first training))) - [train-num 1 sentence-size embedding-size]) ;; has to be channel x y + train-data-shape) :label (ndarray/array (into [] (flatten (mapv last training))) [train-num])} :test {:data (ndarray/array (into [] (flatten (mapv first test))) - [test-num 1 sentence-size embedding-size]) ;; has to be channel x y + test-data-shape) :label (ndarray/array (into [] (flatten (mapv last test))) [test-num])}})) +(defn get-data-symbol [num-embed sentence-size batch-size vocab-size pretrained-embedding] + (if pretrained-embedding + (sym/variable "data") + (as-> (sym/variable "data") data + (sym/embedding "vocab_embed" {:data data :input-dim vocab-size :output-dim num-embed}) + (sym/reshape {:data data :target-shape [batch-size 1 sentence-size num-embed]})))) + (defn make-filter-layers [{:keys [input-x num-embed sentence-size] :as config} filter-size] (as-> (sym/convolution {:data input-x @@ -71,9 +85,9 @@ ;;; convnet with multiple filter sizes ;; from Convolutional Neural Networks for Sentence Classification by Yoon Kim -(defn get-multi-filter-convnet [num-embed sentence-size batch-size] +(defn get-multi-filter-convnet [num-embed sentence-size batch-size vocab-size pretrained-embedding] (let [filter-list [3 4 5] - input-x (sym/variable "data") + input-x (get-data-symbol num-embed sentence-size batch-size vocab-size pretrained-embedding) polled-outputs (mapv #(make-filter-layers {:input-x input-x :num-embed num-embed :sentence-size sentence-size} %) filter-list) total-filters (* num-filter (count filter-list)) concat (sym/concat "concat" nil polled-outputs {:dim 1}) @@ -82,10 +96,11 @@ fc (sym/fully-connected "fc1" {:data hdrop :num-hidden num-label})] (sym/softmax-output "softmax" {:data fc}))) -(defn train-convnet [{:keys [devs embedding-size batch-size test-size num-epoch max-examples]}] - (let [glove (data-helper/load-glove glove-file-path) ;; you can also use word2vec - ms-dataset (data-helper/load-ms-with-embeddings mr-dataset-path embedding-size glove max-examples) +(defn train-convnet [{:keys [devs embedding-size batch-size test-size + num-epoch max-examples pretrained-embedding]}] + (let [ms-dataset (data-helper/load-ms-with-embeddings mr-dataset-path max-examples embedding-size {:pretrained-embedding pretrained-embedding}) sentence-size (:sentence-size ms-dataset) + vocab-size (:vocab-size ms-dataset) shuffled (shuffle-data test-size ms-dataset) train-data (mx-io/ndarray-iter [(get-in shuffled [:training :data])] {:label [(get-in shuffled [:training :label])] @@ -97,7 +112,7 @@ :label-name "softmax_label" :data-batch-size batch-size :last-batch-handle "pad"})] - (let [mod (m/module (get-multi-filter-convnet embedding-size sentence-size batch-size) {:contexts devs})] + (let [mod (m/module (get-multi-filter-convnet embedding-size sentence-size batch-size vocab-size pretrained-embedding) {:contexts devs})] (println "Getting ready to train for " num-epoch " epochs") (println "===========") (m/fit mod {:train-data train-data :eval-data test-data :num-epoch num-epoch @@ -111,7 +126,7 @@ ;;; omit max-examples if you want to run all the examples in the movie review dataset ;; to limit mem consumption set to something like 1000 and adjust test size to 100 (println "Running with context devices of" devs) - (train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000}) + (train-convnet {:devs devs :embedding-size 50 :batch-size 10 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove}) ;; runs all the examples #_(train-convnet {:embedding-size 50 :batch-size 100 :test-size 1000 :num-epoch 10}))) diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj index 79665217744a..82ba13087a37 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj +++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj @@ -21,53 +21,84 @@ [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.ndarray :as ndarray] [org.apache.clojure-mxnet.random :as random]) - (:import (java.io DataInputStream)) + (:import (java.io DataInputStream) + (java.nio ByteBuffer ByteOrder)) (:gen-class)) (def w2v-file-path "../../data/GoogleNews-vectors-negative300.bin") ;; the word2vec file path -(def max-vectors 100) ;; If you are using word2vec embeddings and you want to only load part of them - -(defn r-string [dis] - (let [max-size 50 - bs (byte-array max-size) - sb (new StringBuilder)] - (loop [b (.readByte dis) - i 0] - (if (and (not= 32 b) (not= 10 b)) - (do (aset bs i b) - (if (= 49 i) - (do (.append sb (new String bs)) - (recur (.readByte dis) 0)) - (recur (.readByte dis) (inc i)))) - (.append sb (new String bs 0 i)))) - (.toString sb))) - -(defn get-float [b] - (-> 0 - (bit-or (bit-shift-left (bit-and (aget b 0) 0xff) 0)) - (bit-or (bit-shift-left (bit-and (aget b 1) 0xff) 8)) - (bit-or (bit-shift-left (bit-and (aget b 2) 0xff) 16)) - (bit-or (bit-shift-left (bit-and (aget b 3) 0xff) 24)))) +(def EOS "") ;; end of sentence word + +(defn glove-file-path + "Returns the file path to GloVe embedding of the input size" + [embedding-size] + (format "data/glove/glove.6B.%dd.txt" embedding-size)) + +(defn r-string + "Reads a string from the given DataInputStream `dis` until a space or newline is reached." + [dis] + (loop [b (.readByte dis) + bs []] + (if (and (not= 32 b) (not= 10 b)) + (recur (.readByte dis) (conj bs b)) + (new String (byte-array bs))))) + +(defn get-float [bs] + (-> (ByteBuffer/wrap bs) + (.order ByteOrder/LITTLE_ENDIAN) + (.getFloat))) (defn read-float [is] (let [bs (byte-array 4)] (do (.read is bs) (get-float bs)))) -(defn load-google-model [path] - (println "Loading the word2vec model from binary ...") - (with-open [bis (io/input-stream path) - dis (new DataInputStream bis)] - (let [word-size (Integer/parseInt (r-string dis)) - dim (Integer/parseInt (r-string dis)) - _ (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors) - word2vec (reduce (fn [r _] - (assoc r (r-string dis) - (mapv (fn [_] (read-float dis)) (range dim)))) - {} - (range max-vectors))] - (println "Finished") - {:num-embed dim :word2vec word2vec}))) +(defn- load-w2v-vectors + "Lazily loads the word2vec vectors given a data input stream `dis`, + number of words `nwords` and dimensionality `embedding-size`." + [dis embedding-size num-vectors] + (if (= 0 num-vectors) + (list) + (let [word (r-string dis) + vect (mapv (fn [_] (read-float dis)) (range embedding-size))] + (cons [word vect] (lazy-seq (load-w2v-vectors dis embedding-size (dec num-vectors))))))) + +(defn load-word2vec-model + "Loads the word2vec model stored in a binary format from the given `path`. + By default only the first 100 embeddings are loaded." + ([path embedding-size opts] + (println "Loading the word2vec model from binary ...") + (with-open [bis (io/input-stream path) + dis (new DataInputStream bis)] + (let [word-size (Integer/parseInt (r-string dis)) + dim (Integer/parseInt (r-string dis)) + {:keys [max-vectors vocab] :or {max-vectors word-size}} opts + _ (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors) + _ (if (not= embedding-size dim) + (throw (ex-info "Mismatch in embedding size" + {:input-embedding-size embedding-size + :word2vec-embedding-size dim}))) + vectors (load-w2v-vectors dis dim max-vectors) + word2vec (if vocab + (->> vectors + (filter (fn [[w _]] (contains? vocab w))) + (into {})) + (->> vectors + (take max-vectors) + (into {})))] + (println "Finished") + {:num-embed dim :word2vec word2vec}))) + ([path embedding-size] + (load-word2vec-model path embedding-size {:max-vectors 100}))) + +(defn read-text-embedding-pairs [rdr] + (for [^String line (line-seq rdr) + :let [fields (.split line " ")]] + [(aget fields 0) + (mapv #(Float/parseFloat ^String %) (rest fields))])) + +(defn load-glove [glove-file-path] + (println "Loading the glove pre-trained word embeddings from " glove-file-path) + (into {} (read-text-embedding-pairs (io/reader glove-file-path)))) (defn clean-str [s] (-> s @@ -84,9 +115,12 @@ (string/replace #"\)" " ) ") (string/replace #"\?" " ? ") (string/replace #" {2,}" " ") - (string/trim)));; Loads MR polarity data from files, splits the data into words and generates labels. - ;; Returns split sentences and labels. -(defn load-mr-data-and-labels [path max-examples] + (string/trim))) + +(defn load-mr-data-and-labels + "Loads MR polarity data from files, splits the data into words and generates labels. + Returns split sentences and labels." + [path max-examples] (println "Loading all the movie reviews from " path) (let [positive-examples (mapv #(string/trim %) (-> (slurp (str path "/rt-polarity.pos")) (string/split #"\n"))) @@ -104,41 +138,68 @@ negative-labels (mapv (constantly 0) negative-examples)] {:sentences x-text :labels (into positive-labels negative-labels)})) -;; Pads all sentences to the same length. The length is defined by the longest sentence. -;; Returns padded sentences. -(defn pad-sentences [sentences] - (let [padding-word "" +(defn pad-sentences + "Pads all sentences to the same length where the length is defined by the longest sentence. Returns padded sentences." + [sentences] + (let [padding-word EOS sequence-len (apply max (mapv count sentences))] (mapv (fn [s] (let [diff (- sequence-len (count s))] (if (pos? diff) (into s (repeat diff padding-word)) s))) - sentences)));; Map sentences and labels to vectors based on a pretrained embeddings -(defn build-input-data-with-embeddings [sentences embedding-size embeddings] - (mapv (fn [sent] - (mapv (fn [word] (or (get embeddings word) - (ndarray/->vec (random/uniform -0.25 0.25 [embedding-size])))) - sent)) - sentences)) - -(defn load-ms-with-embeddings [path embedding-size embeddings max-examples] - (println "Translating the movie review words into the embeddings") + sentences))) + +(defn build-vocab-embeddings + "Returns the subset of `embeddings` for words from the `vocab`. + Embeddings for words not in the vocabulary are initialized randomly + from a uniform distribution." + [vocab embedding-size embeddings] + (into {} + (mapv (fn [[word _]] + [word (or (get embeddings word) + (ndarray/->vec (random/uniform -0.25 0.25 [embedding-size])))]) + vocab))) + +(defn build-input-data-with-embeddings + "Map sentences and labels to vectors based on a pretrained embeddings." + [sentences embeddings] + (mapv (fn [sent] (mapv #(embeddings %) sent)) sentences)) + +(defn build-vocab + "Creates a vocabulary for the data set based on frequency of words. + Returns a map from words to unique indices." + [sentences] + (let [words (flatten sentences) + wc (reduce + (fn [m w] (update-in m [w] (fnil inc 0))) + {} + words) + sorted-wc (sort-by second > wc) + sorted-w (map first sorted-wc)] + (into {} (map vector sorted-w (range (count sorted-w)))))) + +(defn load-ms-with-embeddings + "Loads the movie review sentences data set for the given + `:pretrained-embedding` (e.g. `nil`, `:glove` or `:word2vec`)" + [path max-examples embedding-size {:keys [pretrained-embedding] + :or {pretrained-embedding nil} + :as opts}] (let [{:keys [sentences labels]} (load-mr-data-and-labels path max-examples) sentences-padded (pad-sentences sentences) - data (build-input-data-with-embeddings sentences-padded embedding-size embeddings)] + vocab (build-vocab sentences-padded) + vocab-embeddings (case pretrained-embedding + :glove (->> (load-glove (glove-file-path embedding-size)) + (build-vocab-embeddings vocab embedding-size)) + :word2vec (->> (load-word2vec-model w2v-file-path embedding-size {:vocab vocab}) + (:word2vec) + (build-vocab-embeddings vocab embedding-size)) + vocab) + data (build-input-data-with-embeddings sentences-padded vocab-embeddings)] {:data data :label labels :sentence-count (count data) :sentence-size (count (first data)) - :embedding-size embedding-size})) - -(defn read-text-embedding-pairs [rdr] - (for [^String line (line-seq rdr) - :let [fields (.split line " ")]] - [(aget fields 0) - (mapv #(Double/parseDouble ^String %) (rest fields))])) - -(defn load-glove [glove-file-path] - (println "Loading the glove pre-trained word embeddings from " glove-file-path) - (into {} (read-text-embedding-pairs (io/reader glove-file-path)))) + :embedding-size embedding-size + :vocab-size (count vocab) + :pretrained-embedding pretrained-embedding})) diff --git a/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj index 918a46f474d8..744307e3e363 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj +++ b/contrib/clojure-package/examples/cnn-text-classification/test/cnn_text_classification/classifier_test.clj @@ -16,29 +16,33 @@ ;; (ns cnn-text-classification.classifier-test - (:require - [clojure.test :refer :all] - [org.apache.clojure-mxnet.module :as module] - [org.apache.clojure-mxnet.ndarray :as ndarray] - [org.apache.clojure-mxnet.util :as util] - [org.apache.clojure-mxnet.context :as context] - [cnn-text-classification.classifier :as classifier])) + (:require [clojure.test :refer :all] + [org.apache.clojure-mxnet.module :as module] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.util :as util] + [org.apache.clojure-mxnet.context :as context] + [cnn-text-classification.classifier :as classifier])) -; -; The one and unique classifier test -; -(deftest classifier-test - (let [train - (classifier/train-convnet - {:devs [(context/default-context)] - :embedding-size 50 - :batch-size 10 - :test-size 100 - :num-epoch 1 - :max-examples 1000})] +(deftest classifier-with-embeddings-test + (let [train (classifier/train-convnet + {:devs [(context/default-context)] + :embedding-size 50 + :batch-size 10 + :test-size 100 + :num-epoch 1 + :max-examples 1000 + :pretrained-embedding :glove})] (is (= ["data"] (util/scala-vector->vec (module/data-names train)))) - (is (= 20 (count (ndarray/->vec (-> train module/outputs first first))))))) - ;(prn (util/scala-vector->vec (data-shapes train))) - ;(prn (util/scala-vector->vec (label-shapes train))) - ;(prn (output-names train)) - ;(prn (output-shapes train)) \ No newline at end of file + (is (= 20 (count (ndarray/->vec (-> train module/outputs ffirst))))))) + +(deftest classifier-without-embeddings-test + (let [train (classifier/train-convnet + {:devs [(context/default-context)] + :embedding-size 50 + :batch-size 10 + :test-size 100 + :num-epoch 1 + :max-examples 1000 + :pretrained-embedding nil})] + (is (= ["data"] (util/scala-vector->vec (module/data-names train)))) + (is (= 20 (count (ndarray/->vec (-> train module/outputs ffirst)))))))