diff --git a/contrib/clojure-package/examples/cnn-text-classification/README.md b/contrib/clojure-package/examples/cnn-text-classification/README.md index f2ed939bee16..8f8e6200ec7c 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/README.md +++ b/contrib/clojure-package/examples/cnn-text-classification/README.md @@ -29,8 +29,7 @@ You also must download the glove word embeddings. The suggested one to use is th ## Usage You can run through the repl with -`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})` - +`(train-convnet {:devs [(context/default-context)] :embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})` or `JVM_OPTS="-Xmx1g" lein run` (cpu) @@ -49,6 +48,21 @@ and then run - `lein uberjar` - `java -Xms1024m -Xmx2048m -jar target/cnn-text-classification-0.1.0-SNAPSHOT-standalone.jar` +## Usage with fastText + +Using fastText instead of glove is fairly straightforward, as the pretrained embedding format is very similar. + +Download the 'Simple English' pretrained wiki word vectors (text) from the fastText +[site](https://fasttext.cc/docs/en/pretrained-vectors.html) and place them in the +`data/fasttext` directory. Alternatively just run `./get_fasttext_data.sh`. + +Then you can run training on a subset of examples through the repl using: +``` +(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :fasttext}) +``` + +Expect a validation accuracy of `~0.67` with the above parameters. + ## Usage with word2vec You can also use word2vec embeddings in order to train the text classification model. @@ -58,7 +72,7 @@ you'll need to unzip them and place them in the `contrib/clojure-package/data` d Then you can run training on a subset of examples through the repl using: ``` -(train-convnet {:embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec}) +(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec}) ``` Note that loading word2vec embeddings consumes memory and takes some time. @@ -66,7 +80,7 @@ You can also train them using `JVM_OPTS="-Xmx8g" lein run` once you've modified the parameters to `train-convnet` (see above) in `src/cnn_text_classification/classifier.clj`. In order to run training with word2vec on the complete data set, you will need to run: ``` -(train-convnet {:embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec}) +(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec}) ``` You should be able to achieve an accuracy of `~0.78` using the parameters above. diff --git a/contrib/clojure-package/examples/cnn-text-classification/get_fasttext_data.sh b/contrib/clojure-package/examples/cnn-text-classification/get_fasttext_data.sh new file mode 100755 index 000000000000..2bfe96659402 --- /dev/null +++ b/contrib/clojure-package/examples/cnn-text-classification/get_fasttext_data.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +mkdir -p data/fasttext +cd data/fasttext +wget https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.vec diff --git a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj index 82ba13087a37..df132c3167cd 100644 --- a/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj +++ b/contrib/clojure-package/examples/cnn-text-classification/src/cnn_text_classification/data_helper.clj @@ -33,6 +33,8 @@ [embedding-size] (format "data/glove/glove.6B.%dd.txt" embedding-size)) +(def fasttext-file-path "data/fasttext/wiki.simple.vec") + (defn r-string "Reads a string from the given DataInputStream `dis` until a space or newline is reached." [dis] @@ -62,7 +64,7 @@ vect (mapv (fn [_] (read-float dis)) (range embedding-size))] (cons [word vect] (lazy-seq (load-w2v-vectors dis embedding-size (dec num-vectors))))))) -(defn load-word2vec-model +(defn load-word2vec-model! "Loads the word2vec model stored in a binary format from the given `path`. By default only the first 100 embeddings are loaded." ([path embedding-size opts] @@ -75,8 +77,8 @@ _ (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors) _ (if (not= embedding-size dim) (throw (ex-info "Mismatch in embedding size" - {:input-embedding-size embedding-size - :word2vec-embedding-size dim}))) + {:input-embedding-size embedding-size + :word2vec-embedding-size dim}))) vectors (load-w2v-vectors dis dim max-vectors) word2vec (if vocab (->> vectors @@ -88,17 +90,30 @@ (println "Finished") {:num-embed dim :word2vec word2vec}))) ([path embedding-size] - (load-word2vec-model path embedding-size {:max-vectors 100}))) + (load-word2vec-model! path embedding-size {:max-vectors 100}))) -(defn read-text-embedding-pairs [rdr] - (for [^String line (line-seq rdr) +(defn read-text-embedding-pairs [pairs] + (for [^String line pairs :let [fields (.split line " ")]] [(aget fields 0) (mapv #(Float/parseFloat ^String %) (rest fields))])) -(defn load-glove [glove-file-path] +(defn load-glove! [glove-file-path] (println "Loading the glove pre-trained word embeddings from " glove-file-path) - (into {} (read-text-embedding-pairs (io/reader glove-file-path)))) + (->> (io/reader glove-file-path) + line-seq + read-text-embedding-pairs + (into {}))) + +(def remove-fasttext-metadata rest) + +(defn load-fasttext! [fasttext-file-path] + (println "Loading the fastText pre-trained word embeddings from " fasttext-file-path) + (->> (io/reader fasttext-file-path) + line-seq + remove-fasttext-metadata + read-text-embedding-pairs + (into {}))) (defn clean-str [s] (-> s @@ -188,9 +203,11 @@ sentences-padded (pad-sentences sentences) vocab (build-vocab sentences-padded) vocab-embeddings (case pretrained-embedding - :glove (->> (load-glove (glove-file-path embedding-size)) + :glove (->> (load-glove! (glove-file-path embedding-size)) (build-vocab-embeddings vocab embedding-size)) - :word2vec (->> (load-word2vec-model w2v-file-path embedding-size {:vocab vocab}) + :fasttext (->> (load-fasttext! fasttext-file-path) + (build-vocab-embeddings vocab embedding-size)) + :word2vec (->> (load-word2vec-model! w2v-file-path embedding-size {:vocab vocab}) (:word2vec) (build-vocab-embeddings vocab embedding-size)) vocab)