Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
update notebook example and exported markdown
Browse files Browse the repository at this point in the history
  • Loading branch information
gigasquid committed Apr 26, 2019
1 parent da4db39 commit 8e8b936
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 151 deletions.
125 changes: 52 additions & 73 deletions contrib/clojure-package/examples/bert/fine-tune-bert.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@
"source": [
"(defn fine-tune-model\n",
" \"msymbol: the pretrained network symbol\n",
" arg-params: the argument parameters of the pretrained model\n",
" num-classes: the number of classes for the fine-tune datasets\"\n",
" num-classes: the number of classes for the fine-tune datasets\n",
" dropout: the dropout rate\"\n",
" [msymbol {:keys [num-classes dropout]}]\n",
" (as-> msymbol data\n",
" (sym/dropout {:data data :p dropout})\n",
Expand Down Expand Up @@ -287,7 +287,6 @@
}
],
"source": [
"\n",
"(defn pre-processing\n",
" \"Preprocesses the sentences in the format that BERT is expecting\"\n",
" [ctx idx->token token->idx train-item]\n",
Expand Down Expand Up @@ -348,7 +347,7 @@
{
"data": {
"text/plain": [
"#object[org.apache.mxnet.io.NDArrayIter 0x4195d68 \"non-empty iterator\"]"
"#object[org.apache.mxnet.io.NDArrayIter 0x34050a17 \"non-empty iterator\"]"
]
},
"execution_count": 7,
Expand Down Expand Up @@ -418,18 +417,63 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Speedometer: epoch 0 count 0 metric [accuracy 0.6875]\n",
"Speedometer: epoch 0 count 1 metric [accuracy 0.5625]\n",
"Speedometer: epoch 0 count 2 metric [accuracy 0.5729167]\n",
"Speedometer: epoch 0 count 3 metric [accuracy 0.5390625]\n",
"Speedometer: epoch 0 count 4 metric [accuracy 0.5]\n",
"Speedometer: epoch 0 count 5 metric [accuracy 0.5364583]\n",
"Speedometer: epoch 0 count 6 metric [accuracy 0.54910713]\n",
"Speedometer: epoch 0 count 7 metric [accuracy 0.56640625]\n",
"Speedometer: epoch 0 count 8 metric [accuracy 0.5763889]\n",
"Speedometer: epoch 0 count 9 metric [accuracy 0.565625]\n",
"Speedometer: epoch 0 count 10 metric [accuracy 0.56534094]\n",
"Speedometer: epoch 0 count 11 metric [accuracy 0.5729167]\n",
"Speedometer: epoch 0 count 12 metric [accuracy 0.5769231]\n",
"Speedometer: epoch 1 count 0 metric [accuracy 0.625]\n",
"Speedometer: epoch 1 count 1 metric [accuracy 0.65625]\n",
"Speedometer: epoch 1 count 2 metric [accuracy 0.6354167]\n",
"Speedometer: epoch 1 count 3 metric [accuracy 0.6484375]\n",
"Speedometer: epoch 1 count 4 metric [accuracy 0.6375]\n",
"Speedometer: epoch 1 count 5 metric [accuracy 0.625]\n",
"Speedometer: epoch 1 count 6 metric [accuracy 0.63839287]\n",
"Speedometer: epoch 1 count 7 metric [accuracy 0.65234375]\n",
"Speedometer: epoch 1 count 8 metric [accuracy 0.6666667]\n",
"Speedometer: epoch 1 count 9 metric [accuracy 0.653125]\n",
"Speedometer: epoch 1 count 10 metric [accuracy 0.64772725]\n",
"Speedometer: epoch 1 count 11 metric [accuracy 0.6536458]\n",
"Speedometer: epoch 1 count 12 metric [accuracy 0.65384614]\n",
"Speedometer: epoch 2 count 0 metric [accuracy 0.78125]\n",
"Speedometer: epoch 2 count 1 metric [accuracy 0.65625]\n",
"Speedometer: epoch 2 count 2 metric [accuracy 0.65625]\n",
"Speedometer: epoch 2 count 3 metric [accuracy 0.6875]\n",
"Speedometer: epoch 2 count 4 metric [accuracy 0.69375]\n",
"Speedometer: epoch 2 count 5 metric [accuracy 0.703125]\n",
"Speedometer: epoch 2 count 6 metric [accuracy 0.6964286]\n",
"Speedometer: epoch 2 count 7 metric [accuracy 0.69921875]\n",
"Speedometer: epoch 2 count 8 metric [accuracy 0.7013889]\n",
"Speedometer: epoch 2 count 9 metric [accuracy 0.690625]\n",
"Speedometer: epoch 2 count 10 metric [accuracy 0.69034094]\n",
"Speedometer: epoch 2 count 11 metric [accuracy 0.6953125]\n",
"Speedometer: epoch 2 count 12 metric [accuracy 0.7019231]\n"
]
},
{
"data": {
"text/plain": [
"#object[org.apache.mxnet.module.Module 0x3dbc97e6 \"org.apache.mxnet.module.Module@3dbc97e6\"]"
"#object[org.apache.mxnet.module.Module 0x35e65d46 \"org.apache.mxnet.module.Module@35e65d46\"]"
]
},
"execution_count": 9,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -445,72 +489,7 @@
" :arg-params (m/arg-params bert-base)\n",
" :aux-params (m/aux-params bert-base)\n",
" :optimizer (optimizer/adam {:learning-rate 5e-6 :episilon 1e-9})\n",
" :batch-end-callback (callback/speedometer batch-size 1)})})\n",
";;; Note you can check your `lein jupyter notebook` terminal to see progress in the training \n",
";;; example \n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [1]\tSpeed: 0.76 samples/sec\tTrain-accuracy=0.562500\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [2]\tSpeed: 0.86 samples/sec\tTrain-accuracy=0.572917\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [3]\tSpeed: 0.97 samples/sec\tTrain-accuracy=0.539063\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [4]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.500000\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [5]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.536458\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [6]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.549107\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [7]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.566406\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [8]\tSpeed: 1.07 samples/sec\tTrain-accuracy=0.576389\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [9]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.565625\n",
";; WARN org.apache.mxnet.WarnIfNotDisposed: LEAK: [one-time warning] An instance of org.apache.mxnet.Symbol was not disposed. Set property mxnet.traceLeakedObjects to true to enable tracing\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [10]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.565341\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [11]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.572917\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[0] Batch [12]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.576923\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[0] Train-accuracy=0.5769231\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[0] Time cost=407219\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [1]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.656250\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [2]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.635417\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [3]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.648438\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [4]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.637500\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [5]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.625000\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [6]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.638393\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [7]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.652344\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [8]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.666667\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [9]\tSpeed: 0.97 samples/sec\tTrain-accuracy=0.653125\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [10]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.647727\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [11]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.653646\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[1] Batch [12]\tSpeed: 1.04 samples/sec\tTrain-accuracy=0.653846\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[1] Train-accuracy=0.65384614\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[1] Time cost=404094\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [1]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.656250\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [2]\tSpeed: 1.06 samples/sec\tTrain-accuracy=0.656250\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [3]\tSpeed: 1.06 samples/sec\tTrain-accuracy=0.687500\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [4]\tSpeed: 1.07 samples/sec\tTrain-accuracy=0.693750\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [5]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.703125\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [6]\tSpeed: 1.07 samples/sec\tTrain-accuracy=0.696429\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [7]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.699219\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [8]\tSpeed: 1.05 samples/sec\tTrain-accuracy=0.701389\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [9]\tSpeed: 1.03 samples/sec\tTrain-accuracy=0.690625\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [10]\tSpeed: 0.99 samples/sec\tTrain-accuracy=0.690341\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [11]\tSpeed: 0.90 samples/sec\tTrain-accuracy=0.695313\n",
";; INFO org.apache.mxnet.Callback$Speedometer: Epoch[2] Batch [12]\tSpeed: 0.82 samples/sec\tTrain-accuracy=0.701923\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[2] Train-accuracy=0.7019231\n",
";; INFO org.apache.mxnet.module.BaseModule: Epoch[2] Time cost=411626\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"High level predict score is [accuracy 0.7692308]\n"
]
}
],
"source": [
"(let [score (m/score fine-tune-model {:eval-data train-data :eval-metric (eval-metric/accuracy)})]\n",
" (println \"High level predict score is \" score))"
" :batch-end-callback (callback/speedometer batch-size 1)})})\n"
]
}
],
Expand Down
Loading

0 comments on commit 8e8b936

Please sign in to comment.