Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Clojure] Add fastText example #15340

Merged
merged 11 commits into from
Jun 30, 2019
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ You also must download the glove word embeddings. The suggested one to use is th
## Usage

You can run through the repl with
`(train-convnet {:embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})`

`(train-convnet {:devs [(context/default-context)] :embedding-size 50 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :glove})`
or
`JVM_OPTS="-Xmx1g" lein run` (cpu)

Expand All @@ -49,6 +48,21 @@ and then run
- `lein uberjar`
- `java -Xms1024m -Xmx2048m -jar target/cnn-text-classification-0.1.0-SNAPSHOT-standalone.jar`

## Usage with fastText

Using fastText instead of glove is fairly straightforward, as the pretrained embedding format is very similar.

Download the 'Simple English' pretrained wiki word vectors (text) from the fastText
[site](https://fasttext.cc/docs/en/pretrained-vectors.html) and place them in the
`data/fasttext` directory. Alternatively just run `./get_fasttext_data.sh`.

Then you can run training on a subset of examples through the repl using:
```
(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :fasttext})
```

Expect a validation accuracy of `~0.67` with the above parameters.

## Usage with word2vec

You can also use word2vec embeddings in order to train the text classification model.
Expand All @@ -58,15 +72,15 @@ you'll need to unzip them and place them in the `contrib/clojure-package/data` d

Then you can run training on a subset of examples through the repl using:
```
(train-convnet {:embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec})
(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 100 :num-epoch 10 :max-examples 1000 :pretrained-embedding :word2vec})
```
Note that loading word2vec embeddings consumes memory and takes some time.

You can also train them using `JVM_OPTS="-Xmx8g" lein run` once you've modified
the parameters to `train-convnet` (see above) in `src/cnn_text_classification/classifier.clj`.
In order to run training with word2vec on the complete data set, you will need to run:
```
(train-convnet {:embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec})
(train-convnet {:devs [(context/default-context)] :embedding-size 300 :batch-size 100 :test-size 1000 :num-epoch 10 :pretrained-embedding :word2vec})
```
You should be able to achieve an accuracy of `~0.78` using the parameters above.

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set -evx

mkdir -p data/fasttext
cd data/fasttext
wget https://dl.fbaipublicfiles.com/fasttext/vectors-wiki/wiki.simple.vec
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
[embedding-size]
(format "data/glove/glove.6B.%dd.txt" embedding-size))

(def fasttext-file-path "data/fasttext/wiki.simple.vec")

(defn r-string
"Reads a string from the given DataInputStream `dis` until a space or newline is reached."
[dis]
Expand Down Expand Up @@ -62,7 +64,7 @@
vect (mapv (fn [_] (read-float dis)) (range embedding-size))]
(cons [word vect] (lazy-seq (load-w2v-vectors dis embedding-size (dec num-vectors)))))))

(defn load-word2vec-model
(defn load-word2vec-model!
"Loads the word2vec model stored in a binary format from the given `path`.
By default only the first 100 embeddings are loaded."
([path embedding-size opts]
Expand All @@ -75,8 +77,8 @@
_ (println "Processing with " {:dim dim :word-size word-size} " loading max vectors " max-vectors)
_ (if (not= embedding-size dim)
(throw (ex-info "Mismatch in embedding size"
{:input-embedding-size embedding-size
:word2vec-embedding-size dim})))
{:input-embedding-size embedding-size
:word2vec-embedding-size dim})))
vectors (load-w2v-vectors dis dim max-vectors)
word2vec (if vocab
(->> vectors
Expand All @@ -88,17 +90,30 @@
(println "Finished")
{:num-embed dim :word2vec word2vec})))
([path embedding-size]
(load-word2vec-model path embedding-size {:max-vectors 100})))
(load-word2vec-model! path embedding-size {:max-vectors 100})))

(defn read-text-embedding-pairs [rdr]
(for [^String line (line-seq rdr)
(defn read-text-embedding-pairs [pairs]
(for [^String line pairs
:let [fields (.split line " ")]]
[(aget fields 0)
(mapv #(Float/parseFloat ^String %) (rest fields))]))

(defn load-glove [glove-file-path]
(defn load-glove! [glove-file-path]
(println "Loading the glove pre-trained word embeddings from " glove-file-path)
(into {} (read-text-embedding-pairs (io/reader glove-file-path))))
(->> (io/reader glove-file-path)
line-seq
read-text-embedding-pairs
(into {})))

(def remove-fasttext-metadata rest)

(defn load-fasttext! [fasttext-file-path]
(println "Loading the fastText pre-trained word embeddings from " fasttext-file-path)
(->> (io/reader fasttext-file-path)
line-seq
remove-fasttext-metadata
read-text-embedding-pairs
(into {})))

(defn clean-str [s]
(-> s
Expand Down Expand Up @@ -188,9 +203,11 @@
sentences-padded (pad-sentences sentences)
vocab (build-vocab sentences-padded)
vocab-embeddings (case pretrained-embedding
:glove (->> (load-glove (glove-file-path embedding-size))
:glove (->> (load-glove! (glove-file-path embedding-size))
(build-vocab-embeddings vocab embedding-size))
:word2vec (->> (load-word2vec-model w2v-file-path embedding-size {:vocab vocab})
:fasttext (->> (load-fasttext! fasttext-file-path)
(build-vocab-embeddings vocab embedding-size))
:word2vec (->> (load-word2vec-model! w2v-file-path embedding-size {:vocab vocab})
(:word2vec)
(build-vocab-embeddings vocab embedding-size))
vocab)
Expand Down