From 0d51350ab9fb1cd1fb19ede77fd733329a8b5ef7 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Sat, 29 Dec 2018 13:36:14 -0800 Subject: [PATCH 1/8] Clojure example for fixed label-width captcha recognition --- .../examples/captcha/README.md | 5 + .../examples/captcha/captcha_example.png | Bin 0 -> 9762 bytes .../examples/captcha/get_data.sh | 32 ++++ .../examples/captcha/project.clj | 24 +++ .../examples/captcha/src/captcha/example.clj | 172 ++++++++++++++++++ .../captcha/test/captcha/example_test.clj | 7 + 6 files changed, 240 insertions(+) create mode 100644 contrib/clojure-package/examples/captcha/README.md create mode 100644 contrib/clojure-package/examples/captcha/captcha_example.png create mode 100755 contrib/clojure-package/examples/captcha/get_data.sh create mode 100644 contrib/clojure-package/examples/captcha/project.clj create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/example.clj create mode 100644 contrib/clojure-package/examples/captcha/test/captcha/example_test.clj diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md new file mode 100644 index 000000000000..cc97442f6207 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/README.md @@ -0,0 +1,5 @@ +This is the R version of [captcha recognition](http://blog.xlvector.net/2016-05/mxnet-ocr-cnn/) example by xlvector and it can be used as an example of multi-label training. For a captcha below, we consider it as an image with 4 labels and train a CNN over the data set. + +![](captcha_example.png) + +You can download the images and `.rec` files from [here](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip). Since each image has 4 labels, please remember to use `label_width=4` when generating the `.rec` files. diff --git a/contrib/clojure-package/examples/captcha/captcha_example.png b/contrib/clojure-package/examples/captcha/captcha_example.png new file mode 100644 index 0000000000000000000000000000000000000000..09b84f7190fab4cb391f8a3927b10f44d609aaaf GIT binary patch literal 9762 zcmV+-Cf(VIP)002A)0ssI2ZV4tr001TFNkl$qd)Kg=JPmwuTSN6f=S^pBb7Zny2R zWJ{E6%91IIGHG!GH-H2PR$;5Ed-E>2oOvh`g0KQSx5g`fB>Z>tnC%C^w(T2xTUdOfJL=Hzse)#S6+cuN-_SaD$Q` zA;ZuM0bqH|Pj)(6f!O+^Ns+YV-1}EPY%k5B%eph^&%>uOK4VP(7tT2VoNb-!|Myh$ z`NwHCQL>fJ{p!~H0eMN0Jmw&m5$EwAzdmzo@pjL~EMt|R{BIY29r|I7RgUjJb}PSq zs`)%XefdQ_sp)JH+aXxV4<|pFOz%jMG0v;DmM96qqx706+5ctj`jDAaN_RUQKlF1` zDAWg|{YtA-%@-O${jtPX%a3NxwDW~PyNQx<7MJ|eBe%AUyeMYm>|JVEpSUj)mT7BR zy}25zOioVMYxT3Ob3YqU#=(%y7D%LG5V>u@t4QfaQdLZd6hHWuam9cJ5E1i=jC9r?K{(1A6`V%3T0XBo?Z98iNYZw_IAmok{vkWcTr2SI;xkM%XX+};6 zQ1A2o9k2^qA-I>tiIOeZK7Hqmq>MWv5P_?mY8tudZnwXco6xGIofOIE7r(f1@8%M8 zzFN89SLrSi1Jwu9uJOIs6C&XLpBoX7W)_Qm8%x9D_vOcf?P3xMEKs)`=UuJms$P7T%uhhQl3 zxA%_XLiW;L%e0M;hzJC6p2T_bM6|RaW3<*hMd=uu<}&aGT?Q96W^JI7?c+_Ea)GM~ldjl!7U9Ct7}f>#r$gOfVt}dC>3&e>nPsA#WgbLxya-{jJ@Elv0V7Grid*?8G3^ zX8_0mz!Cre-v6h-LV?O`-n$;}9US}7!MC^GsR3iaXt3NOAp$~N#7U7n5iK2f{@$U{ zvEk8jwcP9VF45JcqQf)diGMu7aUQqgcAG7v zEC+xR4+q`qKOOp6Giq#Ide&m)CKP8$H||}YzH(k&PACV#34jb)WxxUw5P-E71#r(H z4ia)x%sm@E(~@(DxRw1Z$r4$}Bu&6tTUc}*1`Irg9Pclm$3x%+#}6D!lB80p=m9g< z9LrUn7x~20)ob}=SHw9gX>$-v40=m>oMcyF0Y{C6YSLppZ%u0};As zTjwT|J3rk0y+&An=wOmeelhW7O1aJyN{qOt2f&G3nWxE<*oIx#<(&7a3OFx^B>>2a zyj~kRQJiSY`7@J$>F{1^v)n1HASLufenVVl?>zUmBb-OD8MlQETQ7=Yp|{x6aW9RT z0zV5*jvgQM2A&kVN-3GknWfn&dAp4Z0vrG&Ly_JV-1h!$$SAF8S zp67Z4wl)wB40;1!-8}zQ{CCC4I$#{?X7OCUIMupci7Njvf|ESCH8a^=?1D#7K#9UK ziz4TV$!=cc*QRfLHhDG=3j}T%fgw`X%iK5jzY&o?GB{F=s{M_s6L*|FVo7yAxq2qc zBs84FrtN(3W@i!rA`zCn4F^UH3eQjmXb}urG6bA}dMfVbNt`9aPr_MR7mTq=P!_=V zSH~86i`R?E%+bh^mD!T*NtI@K=KKEB>K*_#B3Nd!-JAdV-lddhE;u3gcb9gJ?H+xC z%PrCJ=KQUeH=jbr1S3Nx*S>P=;o$z-ShLyy07p3eheIzswC7-})4t+eP17{V3JOqC z`r-7MJ?)$C?SFG~f|oh7N-%;6+R z6>6q4b66jG+P$xD24{Kp#l)9gkr*C z%}lY7d0M=B;8ZoLI>K3I7XTc`aU7>ss{ZigTf0g-!`QPu9j8evX}6O0=7D8o$T;)3 z8?x|F?Z7MJF9vP^0Nmrqo;thz!BmY6km}oyl<91#=w&8*dj0m}!Rr&8=9xA+ zqXGmg`{l!9hZ>dolQsD+EL}9r7;Y_qAnqotDTek`Muy9cjojGiU@a0MA);`E5W;bsb)S4R#Cb9r?d}zc1polo@&n%3 z3-xE;IrLV^+sxz%{jgE3|MaCF{p{Hv4T%BYdZ1|&HCw3S{Edr6QABk9)$k zxdlSx%nAa8M9Kco6P(~$1j0YH_jzBrXbBjQz=YB-C%$}21jB=f;rav{v>oi)KBrG% ziIcTo-9B&oxe9{FjfN}D@y7lg)#38PY}&{Vo3+NCfsq#~&(w=r#G{hCp*I}oi32Ro z6M@djBPT1RvL{@o0kj3gl(LqbySX&^o5>H_>D*K0*mr!L>xf582xN#Hkt6VJM~b$N zek6GPJ=HQ}1F)GzU|j!@DP+9Nim0Rgtf)WKIV}%e0#tw`P zjBQ;{L{8WY8~bWwFYJFV+SpUFU}w~MHyJTJPTt<7upF$QY4wkN_iCd(bG=9=c>3kdnMJ9lts$`u&lz+ z4>u=xqtxnF(0_Oxf|XXrYGX7YlCdf+))L&yllk8KpDuoWE54PnEJ>wj+}n$HmU^A< zz4*o^1Y2u+DoIS5P>u|3VN|C8&(6=!54crFIIimk-ui#OaxDuXs-m*_43kWLaqXPU zB{*mj-4AXP01=S_AYiU=bEh#-|M}>9Pc{8ss3J~dD-FvKLm&b&CU*;CjJ1}r^%eEn z)aoNh^ZF9J*j+--0D-JU!PXKS7s+q0pPpX0-Gwd%Kt8D=)v3OiTpgPqW5^;e^u2Xl z_@qb{pX!qh%&HAsD`$T1~W+ zN^V!0*J-*aiYSPRq5uE{Mr+D85|y-Mdp2uzBPnXG@P!k(PQ@*1HZ7^&Uq1cim2)yr z*S2h}?G?!vv*!|%A|hE*z1*`-Ev)|e?%B^5zC7Q#^skqHbA9pFY_ApP@roYrujN)Y zw>;;Ef`F7T#DO70iwqb5l)UJTgRhjG5)lCcAboZBD>8PAZLUS4y2hmkDPFoC}(-r&3Y|0pYj}C#vMRJhxlH9 zcQveRBEBnAs{ruv^qKxI+cFn6k4zx~F^BaTmYHmI+rOLmXpRks7Gs87w3;|M`b@N`F$G^J% zyMH?L)8+%!x7}K`yvQe}uIE@Vz<~MB?R_?I?r#|c1jg80gLRQ{K)#W?_)x?}(l6GO zHv84h_xE_a-rW6~b`#aku>$-bJM$b=PfsFc&TAvS@$xkl)>1KQ@3l;YS5P9U1w%oWjH&&zZh&Z6n7fttr z$&ZiDAN<~t*Om#ETFzw)bNNDQ@5kpDXN1TZGljO9u>!d(fT0Lo4Ypm({p#v(-`V|x zX4EKoQ6t!1ty-KViIUP}$gv7#(6l^~k{?BGRCdY*8Q@H~&dQkO^473<$mRaSMe^zG zzruAqS?vcO?4QK+K?#z66 z>+KHpG`OnQToBm4^nL+7>5S6h)y6$uj~7o#5EMqv5v`7zl`fNB{`Zq^)A^ zR(I-RaAS*J9wkP9`>grqv)~|Wxv2bKZwYi-3VT{#l z_3t14UVPz$dwPM4Mb3aBJYblJv@%2%fdN1yfPPODkzUwb5E6;H+-npD+ zSz0LT0JxZFdA@ULpWGL@;TCmToX6K@uE#22jD;eI+#vY2Q(p4IDvu%_Bql?SnaDqy z{ggW>b&3wpZIP)=X7ZtN#K;Mo!O(v_{PR1VnUAK=WO=4hBS+u>0E{!n84#JmWJ@_{ zs8!1^H=ZAk8e7r-+r_#y=Nl626Fcqigb}3>5cx zfx1K8OVbmzpgOn(f@LP#owv-Oi#}&_U0w#ZpCAWT3Nx+F+6fZD1ZV2jpM<-gFE&= zy^z^jXL7?m7%Tx15g4$AE-u}=EGcujvojn%)EIA+>TB)$&zC2@0I+B1AFJ$xAL#@ySANKWjV5gPj!w=^ zn!*?jh8r#zN%B4Kh%>%#@7``N-Z9u5EY(V0>FL#Bt#zf=cfvw16rSTVpDdApdvw66CS`a$G{p$O1`Eli$fd79-#u9Q-R)?kq#X^Jp-dFZ+RX^wRf z_(8o=fA5)hMsdWk4geISG#Ld&|AD?P+9=1Ddd>kXC%-aPcZXDi2o*dBkdxzbTypWw(3e=z<=iARp*hFM*rJNZnP$9~{NLAc4uk-!UF$8oJkhL&23 zS8iR6vxFhz$T{Oc0GfiH8)3MSn(F7?kz=#m*i{?xnNQYoh9VNk7;GVIiO766;DI0j zpL@G%JJ(W`s6>*KB$rF~$&%GJU&;aujyeYq?K~K6Q;|k~X?)j#llx9Y7#5w(0-*zQ ze6f>uGM!qo+j5E?4UrS}f0jR4pWu>T8t`jJ!UMuGL{O-Ekt`eq0#+5stSW~H{ouLr zlfLhfwh74_y~#L_3sW$c<+dOqofM#8ZG3yVwPn#X%Y(2)b3Xvbad*aPgLobwB zi4!YObxm9us6D;XoOSV#nk)ct9hVV0nZqpO4Bf~B&ZJ3CPoL@2xsj^mzy)xD0syL5 z0lG6{PB|jRjG*65j+X+o>7&2zL{o?w$ zZuYe5>(-)A&FPslu!#h_t~%?lY8 zz$g^Kf#LCGqoyyrH+xepIiF|=09IS8jVd${0Yhs{$bvVAUtg&h^aE`D{W^{_5LE}f zYDhj>N}-sH?|gjWPjUXl?v2<8{W;$1P4%hSxMT;w_xyTgs4`M5^*TM0000Cb4^B0n zpUdYjE?fbCz2!Y;Tj$;!c_Uo0u`^<+%@CL+BQ#q|md;K7T{4LrK|_$bz4@-!Scl+M z@$@81WG<6D?c2DlFUt>%qKyucCz(u(G}9>)qU4rF>N`U(Oq5KtY^QU7X?@vB`|KhS zIhyIj2v-;m$bc+~lm|RmsrCz{?zZk$E7c$fwtg--@Pgy}kKOqDREHhauWK>7op^ajj(>aex%>l+19ze<G&!_TjtjZLqg6lcix z500%ou)0d0_l3Bpdl?l53^UG)A+?#J*ynX4_XI|L`!F~|n02vTPUf4_HXhkW{N~azG`Z3rcHZX@vh9l3YTnPn= zzO3Ku-EL*`{qj(2tSE3OEgz~LaD|t4Enz!@92rYZ{`;>_mrq6r&B%`cAlJoWy10-o zCOT=rHZod-gb{-TkRx*hA~Jdc2PsyP3LEgcQcfNcQaROp{$AX2m6P&<2^K4PW9C+4 z-;nFN_cy`-5+zSLVZY1IXaCxY+nHaWLX#UO5I5#-TCj{G5Fmmh`o~bV;DIpUWey1- z_6tVv$lgQC0s2+d4t;ekGkIQqt-%ly1MFnob~fL}#naPg5|sb|AQF*FBFx;8VCQoO zPI{hS?=}AK)Q73bkh4Tf8>0_GbqpU$0Kw4Fq!iIRXwnp_lBEIrpe_Uzsh1EpH%hp~x6 z>&+8^02Hiela7j$JOO~BD6ULSEYM=E_50U9`oZw)M|K~qiP}SnQAKg}=GDTe{tG4s zxufi+X9(8i6247mAAM3ce|yV#$YP9~u*{=s5iQaZ39Hb?lwQ6wF&GSZYvM8&dTAo*D270cGkCz5dgq~DJ;>3A?Uuk zK0bW-FWoPC5EB>+RBH}hN+-HD1^^;(EO(6vrF(Mt=(GDyG|P>?B8l8EWMN2QYBJ;q z9EshhA}?}nRi<-kbCs&I_rBaSJW_+&<}a-%ipAbiUKE+l``!m3k|n8SE1g?Q;WE~8 z2`lLAV=JVt#6wQ_0{}7{5(BeIJ9iX$nC4ll)9Um(eT!qIvW@d(cCi&}i9UursPfuK zurto%N**ooCG7z9s0_igW5;XdZ>-66XBFphOU_B1-p@B8NS(H``INGY3s;T}9$p*D zey#CRYoRTW0|zoxxcYqOBJ(*|bVxW>_?dIS866xN7_KyG!8hdtC9m}Q{+Ivl>)%K> zMUK^Swd|Dob-P5Mj1r#1gwYPgES>GNxMG1HF5jZER`Lo}T%Nu>nckrg5bsx?8D*r+ zHZO==#U*T?&b8S10QiH{@$u@BV|$M-UFzIY(@NNa)Iza%r!#YRes=G!v0AC>dw!fH ziArX=v!C7js|pPo^3Zwlz{$W1CTAy!A#es-A&Dvu?C^$nRGQ1ZEbredG23fhnVE>0 z0OMhQ+7XKHK3Bhh8wg_kfq* zuM1eUMwfbvEj_p7cKZ$mSX*+!s2WWrcTVlxFog2>T&|1t30~=C0$hH5@YLbqgWZL= z$Ft02ZNBikn;#BAZF+I~mBXjXL3uG<`smspXP}k(d9IZJrvy>es{nvs4ZBMT5uiqf zhzwVnlFR$|D(FH-PPCx9{X1EFO`+!X!&7168gZ+OP zq95>@?Y7Qb`?KMOTriH+d6CJygi)h8G!WJrL4BFFG)v<=Y30^xV+FDP>eTL_uC5*P zcY_THUX?l~qNSqa-SL2kxq>(uV>(4VBk!e4#f9Wj9%{z9wG=341e+)KpD20J?BYF{ zr4l3r0Gxx5UFvDvFzlBXd6s3jJJWw^eO3sC2EkBZyj?-Ft>^2dApm%5*Bfg$VAOE= z#udQ@LDBbxNLyC$dR{qL-9K8{v$s4Nh9LkXN?u&J(%0}S(-m)W;0G97$NtA?i z?riJa9)DMrmvg0*uo^WF9d5xX+ECTjtL7?j(gwg8_yS_=8k==h$0= zQO$F`&|EKX>^shrB$u&GAYuSO1Ugkt4{A~6nbBj-q2@|saNP>>Rd?;I+@#_>{?U+N{+ksf+_CF>+c&Yqscdr_gEh3i*0SFF=o&Ju&VBm>d zeEJC3P=eRK#F7{N_{b0c_2O^l^g>E0bAlt$oYXUfuu_S3n)Z7s@2WkU*YoqH9 z3;--p28Gs|k%NFPi8A*GqRNqxLqh{YQS^=0-FgN2@+S56p6^$9`TNJ-P?z+Sxsy_n znjASMN|Lp|k3TNDg(->|eHSy9c&RA6=x_iaZ2}&Qly~+|&HO{KjpCrPUqZxp_rLjn zzWVo$D%=`lf&*trK!(VehzHGJXt>l^KBXV|Vc_~6IiwA`0ubXoiFtCb+u{xffM`WE ztIXu*SI*AiqK+&wgMg^9B!ZHNBL5+mw1$4@yIuf6rixIM2HeW=%Atch4h)2~wP#tk zs4Z6sy2st0oOt)f;;j!RPbVs7E=SH&lO@(_RF>(?DsquIoH-m2p|HO726*+*?(y(H z1;GGN@}l9WF%s@9x=IzsY9ei25WetSatHm|+k3vh>h+N!YlzJKq^m-6CBEKRs$ZGD z{PLlfBR^6~C7G1DoLRivmh+OO=n7;=1U_?{eCUNkC)YiR`t5;*uOE12>B9hy9(s<-C+qmfB1jiHwQR0dN3-EmsO_VmVfH_|C~U z8gG$hJ>IkSpY-8 zF*o^QrhjwccO#V@I|g@UCRbWto4GMLKh@2;OF_(pAa?&9g~oes&xxV(Q1DQK{nAh{ zu*(@PL%4frWTdjQ7S`52wyguy>poL*qxTNH)wlEiaPu^26YC5aLym|D7D!u56j;Bi zqmzweJIceAp#0Fjo*uSOuwQbR^Sz_5eR=b%13UKzUg&}d{2=n8wcBc5XKckA` z@Sa2K-VyLP*CBi}efRz~#&MDS)4`v^9w9@*@A!`9FVo?{CvEi;=irKVpe25s}c$n`9d_#f?c^`^*Y|Gm9i?g7YPSalYU^GE^EAHhq<^W()Z(2lZ+tZg(_t2% zk3r$p+{#f$y_L~Ntiay7)o^HRe4=x$Qk3E)X-dQc2L%uSS!^hD)HL8tAHtrrD zs|y#ugKs`7iXzUFd%f0gu70o}Iw@p8E6S@0?pHM`!Z}lg^1-3;AHMKry)yJA?1A8zY{aSf5s{jB107*qoM6N<$f)Wn*ga7~l literal 0 HcmV?d00001 diff --git a/contrib/clojure-package/examples/captcha/get_data.sh b/contrib/clojure-package/examples/captcha/get_data.sh new file mode 100755 index 000000000000..baa7f9eb818f --- /dev/null +++ b/contrib/clojure-package/examples/captcha/get_data.sh @@ -0,0 +1,32 @@ +#!/bin/bash + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -evx + +EXAMPLE_ROOT=$(cd "$(dirname $0)"; pwd) + +data_path=$EXAMPLE_ROOT + +if [ ! -f "$data_path/captcha_example.zip" ]; then + wget https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip -P $data_path +fi + +if [ ! -f "$data_path/captcha_example/captcha_train.rec" ]; then + unzip $data_path/captcha_example.zip -d $data_path +fi diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj new file mode 100644 index 000000000000..98d836a4f76a --- /dev/null +++ b/contrib/clojure-package/examples/captcha/project.clj @@ -0,0 +1,24 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(defproject captcha "0.1.0-SNAPSHOT" + :description "Captcha recognition via multi-label classification" + :plugins [[lein-cljfmt "0.5.7"]] + :dependencies [[org.clojure/clojure "1.9.0"] + [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]] + :main ^:skip-aot captcha.example + :profiles {:uberjar {:aot :all}}) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/example.clj b/contrib/clojure-package/examples/captcha/src/captcha/example.clj new file mode 100644 index 000000000000..07eb9d2005b3 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/example.clj @@ -0,0 +1,172 @@ +(ns captcha.example + (:require [clojure.java.io :as io] + [org.apache.clojure-mxnet.callback :as callback] + [org.apache.clojure-mxnet.context :as context] + [org.apache.clojure-mxnet.eval-metric :as eval-metric] + [org.apache.clojure-mxnet.initializer :as initializer] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.optimizer :as optimizer] + [org.apache.clojure-mxnet.symbol :as sym]) + (:gen-class)) + +(def batch-size 8) +(def data-shape [3 30 80]) +(def label-width 4) + +(defonce train-data + (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec" + :path-imglist "captcha_example/captcha_train.lst" + :batch-size batch-size + :label-width label-width + :data-shape data-shape + ;:mean-img "mean.bin" + :mean-r 127 + :mean-g 127 + :mean-b 127 + :mean-a 127 + :scale (/ 1.0 128) + })) + +(defonce eval-data + (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec" + :path-imglist "captcha_example/captcha_test.lst" + :batch-size batch-size + :label-width label-width + :data-shape data-shape + ;:mean-img "mean.bin" + :mean-r 127 + :mean-g 127 + :mean-b 127 + :mean-a 127 + :scale (/ 1.0 128) + })) + +(defn multi-label-accuracy + [label pred] + (let [[nr nc] (ndarray/shape-vec label) + pred-label (-> pred + (ndarray/argmax 1) + (ndarray/reshape [nr nc])) + digit-matches (ndarray/equal label pred-label) + ; captcha-matches (ndarray/equal (ndarray/sum digit-matches 1) 4) + ; num-complete-matches (ndarray/sum captcha-matches) + num-digit-matches (ndarray/sum digit-matches) + [total] (ndarray/->vec num-digit-matches) + ] + ; (println "Fraction:" (float (/ total nr nc))) + (float (/ total nr nc)))) + +(defn get-data-symbol + [] + (let [data (sym/variable "data") + + conv1 (sym/convolution {:data data :kernel [5 5] :num-filter 32}) + pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]}) + relu1 (sym/activation {:data pool1 :act-type "relu"}) + + conv2 (sym/convolution {:data relu1 :kernel [5 5] :num-filter 32}) + pool2 (sym/pooling {:data conv2 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu2 (sym/activation {:data pool2 :act-type "relu"}) + + conv3 (sym/convolution {:data relu2 :kernel [3 3] :num-filter 32}) + pool3 (sym/pooling {:data conv3 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu3 (sym/activation {:data pool3 :act-type "relu"}) + + conv4 (sym/convolution {:data relu3 :kernel [3 3] :num-filter 32}) + pool4 (sym/pooling {:data conv4 :pool-type "avg" :kernel [2 2] :stride [1 1]}) + relu4 (sym/activation {:data pool4 :act-type "relu"}) + + flattened (sym/flatten {:data relu4}) + dropped (sym/dropout {:data flattened :p 0.1}) + fc1 (sym/fully-connected {:data dropped :num-hidden 256}) + fc21 (sym/fully-connected {:data fc1 :num-hidden 10}) + fc22 (sym/fully-connected {:data fc1 :num-hidden 10}) + fc23 (sym/fully-connected {:data fc1 :num-hidden 10}) + fc24 (sym/fully-connected {:data fc1 :num-hidden 10})] + (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0}))) + +(defn get-label-symbol + [] + (as-> (sym/variable "label") label + (sym/transpose {:data label}) + (sym/reshape {:data label :shape [-1]}))) + +(defn create-captcha-net + [] + (let [scores (get-data-symbol) + labels (get-label-symbol)] + (sym/softmax-output {:data scores :label labels}))) + +(comment + (def batch (mx-io/next train-data)) + (mx-io/batch-index batch) + (-> batch mx-io/batch-label first ndarray/->vec) + (-> batch mx-io/batch-data first ndarray/->vec) + (def _mod (m/module (create-captcha-net) + {:data-names ["data"] :label-names ["label"]})) + (m/bind _mod {:data-shapes (mx-io/provide-data-desc train-data) + :label-shapes (mx-io/provide-label-desc train-data)}) + (m/init-params _mod + {:initializer (initializer/uniform 0.01) + :force-init true}) + ; (m/init-params _mod) + (m/forward _mod batch) + (m/output-shapes _mod) + (m/outputs _mod) + (-> batch mx-io/batch-label first ndarray/->vec) + ; (-> _mod m/outputs first first (ndarray/> 0.5) (ndarray/reshape [0 -1 4]) (ndarray/argmax 1) ndarray/->vec) + (-> _mod m/outputs-merged first ndarray/->vec) + (m/backward _mod) + (def out-grads (m/grad-arrays _mod)) + (map #(-> % first ndarray/shape-vec) out-grads)) + +(comment + (def optimizer + (optimizer/sgd + {:learning-rate 0.0001 + :momentum 0.9 + :wd 0.00001 + :clip-gradient 10}))) +(def optimizer + (optimizer/adam + {:learning-rate 0.0002 + :wd 0.00001 + :clip-gradient 10})) + +(defn start + [devs] + (do + (println "Starting the captcha training ...") + (let [_mod (m/module + (create-captcha-net) + {:data-names ["data"] :label-names ["label"] + :contexts devs})] + (m/fit _mod {:train-data train-data + :eval-data eval-data + :num-epoch 20 + :fit-params (m/fit-params + {:kvstore "local" + :batch-end-callback + (callback/speedometer batch-size 50) + :initializer + (initializer/xavier {:factor-type "in" + :magnitude 2.34}) + ;(initializer/uniform 0.01) + :optimizer optimizer + :eval-metric (eval-metric/custom-metric + #(multi-label-accuracy %1 %2) + "accuracy") + })}) + (println "Finished the fit") + _mod))) + +(defn -main + [& args] + (let [[dev dev-num] args + num-devices (Integer/parseInt (or dev-num "1")) + devs (if (= dev ":gpu") + (mapv #(context/gpu %) (range num-devices)) + (mapv #(context/cpu %) (range num-devices)))] + (start devs))) diff --git a/contrib/clojure-package/examples/captcha/test/captcha/example_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/example_test.clj new file mode 100644 index 000000000000..0494d7d5689a --- /dev/null +++ b/contrib/clojure-package/examples/captcha/test/captcha/example_test.clj @@ -0,0 +1,7 @@ +(ns captcha.example-test + (:require [clojure.test :refer :all] + [captcha.example :refer :all])) + +(deftest a-test + (testing "FIXME, I fail." + (is (= 0 1)))) From 9a3f85dfabfce6e96bb4f4b2632318dded3cd0f8 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Sun, 30 Dec 2018 15:57:29 -0800 Subject: [PATCH 2/8] Update evaluation --- .../examples/captcha/src/captcha/example.clj | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/example.clj b/contrib/clojure-package/examples/captcha/src/captcha/example.clj index 07eb9d2005b3..7e661a86e5eb 100644 --- a/contrib/clojure-package/examples/captcha/src/captcha/example.clj +++ b/contrib/clojure-package/examples/captcha/src/captcha/example.clj @@ -22,6 +22,8 @@ :label-width label-width :data-shape data-shape ;:mean-img "mean.bin" + :shuffle true + :seed 42 :mean-r 127 :mean-g 127 :mean-b 127 @@ -46,16 +48,11 @@ (defn multi-label-accuracy [label pred] (let [[nr nc] (ndarray/shape-vec label) - pred-label (-> pred - (ndarray/argmax 1) - (ndarray/reshape [nr nc])) - digit-matches (ndarray/equal label pred-label) - ; captcha-matches (ndarray/equal (ndarray/sum digit-matches 1) 4) - ; num-complete-matches (ndarray/sum captcha-matches) - num-digit-matches (ndarray/sum digit-matches) - [total] (ndarray/->vec num-digit-matches) - ] - ; (println "Fraction:" (float (/ total nr nc))) + label-t (-> label ndarray/transpose (ndarray/reshape [-1])) + pred-label (ndarray/argmax pred 1) + [total] (-> (ndarray/equal label-t pred-label) + ndarray/sum + ndarray/->vec)] (float (/ total nr nc)))) (defn get-data-symbol @@ -79,8 +76,7 @@ relu4 (sym/activation {:data pool4 :act-type "relu"}) flattened (sym/flatten {:data relu4}) - dropped (sym/dropout {:data flattened :p 0.1}) - fc1 (sym/fully-connected {:data dropped :num-hidden 256}) + fc1 (sym/fully-connected {:data flattened :num-hidden 256}) fc21 (sym/fully-connected {:data fc1 :num-hidden 10}) fc22 (sym/fully-connected {:data fc1 :num-hidden 10}) fc23 (sym/fully-connected {:data fc1 :num-hidden 10}) From f9092f92b5c85534e830641c1e4f0541796fe5a5 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Mon, 31 Dec 2018 09:57:05 -0800 Subject: [PATCH 3/8] Better training and inference (w/ cleanup) --- .../examples/captcha/.gitignore | 2 + .../examples/captcha/project.clj | 8 +- .../captcha/src/captcha/infer_ocr.clj | 45 +++++++ .../captcha/{example.clj => train_ocr.clj} | 123 ++++++++---------- 4 files changed, 108 insertions(+), 70 deletions(-) create mode 100644 contrib/clojure-package/examples/captcha/.gitignore create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj rename contrib/clojure-package/examples/captcha/src/captcha/{example.clj => train_ocr.clj} (57%) diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore new file mode 100644 index 000000000000..2ff006761866 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/.gitignore @@ -0,0 +1,2 @@ +/.lein-* +/.nrepl-port diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj index 98d836a4f76a..fa37fecbe035 100644 --- a/contrib/clojure-package/examples/captcha/project.clj +++ b/contrib/clojure-package/examples/captcha/project.clj @@ -20,5 +20,9 @@ :plugins [[lein-cljfmt "0.5.7"]] :dependencies [[org.clojure/clojure "1.9.0"] [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]] - :main ^:skip-aot captcha.example - :profiles {:uberjar {:aot :all}}) + :main ^:skip-aot captcha.train-ocr + :profiles {:train {:main captcha.train-ocr} + :infer {:main captcha.infer-ocr} + :uberjar {:aot :all}} + :aliases {"train" ["with-profile" "train" "run"] + "infer" ["with-profile" "infer" "run"]}) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj new file mode 100644 index 000000000000..2e1db681a46d --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj @@ -0,0 +1,45 @@ +(ns captcha.infer-ocr + (:require [org.apache.clojure-mxnet.dtype :as dtype] + [org.apache.clojure-mxnet.infer :as infer] + [org.apache.clojure-mxnet.layout :as layout] + [org.apache.clojure-mxnet.ndarray :as ndarray])) + +(def batch-size 8) +(def channels 3) +(def height 30) +(def width 80) +(def label-width 4) +(def model-prefix "ocr") + +(defn create-predictor + [] + (let [data-desc {:name "data" + :shape [batch-size channels height width] + :layout layout/NCHW + :dtype dtype/FLOAT32} + label-desc {:name "label" + :shape [batch-size label-width] + :layout layout/NT + :dtype dtype/FLOAT32} + factory (infer/model-factory model-prefix + [data-desc label-desc])] + (infer/create-predictor factory))) + +(defn -main + [& args] + (let [[filename] args + image-fname (or filename "captcha_example.png") + image-ndarray (-> image-fname + infer/load-image-from-file + (infer/reshape-image width height) + (infer/buffered-image-to-pixels [channels height width]) + (ndarray/expand-dims 0)) + label-ndarray (ndarray/zeros [1 label-width]) + predictor (create-predictor) + predictions (-> (infer/predict-with-ndarray + predictor + [image-ndarray label-ndarray]) + first + (ndarray/argmax 1) + ndarray/->vec)] + (println "CAPTCHA output:" (apply str (mapv int predictions))))) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/example.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj similarity index 57% rename from contrib/clojure-package/examples/captcha/src/captcha/example.clj rename to contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj index 7e661a86e5eb..aa296d721ef7 100644 --- a/contrib/clojure-package/examples/captcha/src/captcha/example.clj +++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj @@ -1,4 +1,4 @@ -(ns captcha.example +(ns captcha.train-ocr (:require [clojure.java.io :as io] [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] @@ -21,45 +21,40 @@ :batch-size batch-size :label-width label-width :data-shape data-shape - ;:mean-img "mean.bin" :shuffle true - :seed 42 - :mean-r 127 - :mean-g 127 - :mean-b 127 - :mean-a 127 - :scale (/ 1.0 128) - })) + :seed 42})) (defonce eval-data (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec" :path-imglist "captcha_example/captcha_test.lst" :batch-size batch-size :label-width label-width - :data-shape data-shape - ;:mean-img "mean.bin" - :mean-r 127 - :mean-g 127 - :mean-b 127 - :mean-a 127 - :scale (/ 1.0 128) - })) + :data-shape data-shape})) (defn multi-label-accuracy [label pred] (let [[nr nc] (ndarray/shape-vec label) label-t (-> label ndarray/transpose (ndarray/reshape [-1])) pred-label (ndarray/argmax pred 1) - [total] (-> (ndarray/equal label-t pred-label) - ndarray/sum - ndarray/->vec)] - (float (/ total nr nc)))) + matches (ndarray/equal label-t pred-label) + [digit-matches] (-> matches + ndarray/sum + ndarray/->vec) + [complete-matches] (-> matches + (ndarray/reshape [nc nr]) + (ndarray/sum 0) + (ndarray/equal label-width) + ndarray/sum + ndarray/->vec)] + ; (float (/ digit-matches nr nc)) + (float (/ complete-matches nr)))) (defn get-data-symbol [] (let [data (sym/variable "data") + scaled (sym/div (sym/- data 127) 128) - conv1 (sym/convolution {:data data :kernel [5 5] :num-filter 32}) + conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32}) pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]}) relu1 (sym/activation {:data pool1 :act-type "relu"}) @@ -76,7 +71,8 @@ relu4 (sym/activation {:data pool4 :act-type "relu"}) flattened (sym/flatten {:data relu4}) - fc1 (sym/fully-connected {:data flattened :num-hidden 256}) + dropped (sym/dropout {:data flattened :p 0.25}) + fc1 (sym/fully-connected {:data dropped :num-hidden 256}) fc21 (sym/fully-connected {:data fc1 :num-hidden 10}) fc22 (sym/fully-connected {:data fc1 :num-hidden 10}) fc23 (sym/fully-connected {:data fc1 :num-hidden 10}) @@ -100,63 +96,53 @@ (mx-io/batch-index batch) (-> batch mx-io/batch-label first ndarray/->vec) (-> batch mx-io/batch-data first ndarray/->vec) - (def _mod (m/module (create-captcha-net) + (def model (m/module (create-captcha-net) {:data-names ["data"] :label-names ["label"]})) - (m/bind _mod {:data-shapes (mx-io/provide-data-desc train-data) + (m/bind model {:data-shapes (mx-io/provide-data-desc train-data) :label-shapes (mx-io/provide-label-desc train-data)}) - (m/init-params _mod + (m/init-params model {:initializer (initializer/uniform 0.01) :force-init true}) - ; (m/init-params _mod) - (m/forward _mod batch) - (m/output-shapes _mod) - (m/outputs _mod) + ; (m/init-params model) + (m/forward model batch) + (m/output-shapes model) + (m/outputs model) (-> batch mx-io/batch-label first ndarray/->vec) - ; (-> _mod m/outputs first first (ndarray/> 0.5) (ndarray/reshape [0 -1 4]) (ndarray/argmax 1) ndarray/->vec) - (-> _mod m/outputs-merged first ndarray/->vec) - (m/backward _mod) - (def out-grads (m/grad-arrays _mod)) + ; (-> model m/outputs first first (ndarray/> 0.5) (ndarray/reshape [0 -1 4]) (ndarray/argmax 1) ndarray/->vec) + (-> model m/outputs-merged first ndarray/->vec) + (m/backward model) + (def out-grads (m/grad-arrays model)) (map #(-> % first ndarray/shape-vec) out-grads)) -(comment - (def optimizer - (optimizer/sgd - {:learning-rate 0.0001 - :momentum 0.9 - :wd 0.00001 - :clip-gradient 10}))) (def optimizer (optimizer/adam {:learning-rate 0.0002 :wd 0.00001 :clip-gradient 10})) -(defn start +(defn train-ocr [devs] - (do - (println "Starting the captcha training ...") - (let [_mod (m/module - (create-captcha-net) - {:data-names ["data"] :label-names ["label"] - :contexts devs})] - (m/fit _mod {:train-data train-data - :eval-data eval-data - :num-epoch 20 - :fit-params (m/fit-params - {:kvstore "local" - :batch-end-callback - (callback/speedometer batch-size 50) - :initializer - (initializer/xavier {:factor-type "in" - :magnitude 2.34}) - ;(initializer/uniform 0.01) - :optimizer optimizer - :eval-metric (eval-metric/custom-metric - #(multi-label-accuracy %1 %2) - "accuracy") - })}) - (println "Finished the fit") - _mod))) + (println "Starting the captcha training ...") + (let [model (m/module + (create-captcha-net) + {:data-names ["data"] :label-names ["label"] + :contexts devs})] + (m/fit model {:train-data train-data + :eval-data eval-data + :num-epoch 10 + :fit-params (m/fit-params + {:kvstore "local" + :batch-end-callback + (callback/speedometer batch-size 100) + :initializer + (initializer/xavier {:factor-type "in" + :magnitude 2.34}) + :optimizer optimizer + :eval-metric (eval-metric/custom-metric + #(multi-label-accuracy %1 %2) + "accuracy")})}) + (println "Finished the fit") + model)) (defn -main [& args] @@ -164,5 +150,6 @@ num-devices (Integer/parseInt (or dev-num "1")) devs (if (= dev ":gpu") (mapv #(context/gpu %) (range num-devices)) - (mapv #(context/cpu %) (range num-devices)))] - (start devs))) + (mapv #(context/cpu %) (range num-devices))) + model (train-ocr devs)] + (m/save-checkpoint model {:prefix "ocr" :epoch 0}))) From b5d3caa77a558d4fb67b0184c26407fa20764793 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Mon, 31 Dec 2018 10:27:40 -0800 Subject: [PATCH 4/8] Captcha generation for testing --- .../examples/captcha/.gitignore | 1 + .../examples/captcha/gen_captcha.py | 23 +++++++++++++++++++ 2 files changed, 24 insertions(+) create mode 100755 contrib/clojure-package/examples/captcha/gen_captcha.py diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore index 2ff006761866..e1569bd89020 100644 --- a/contrib/clojure-package/examples/captcha/.gitignore +++ b/contrib/clojure-package/examples/captcha/.gitignore @@ -1,2 +1,3 @@ /.lein-* /.nrepl-port +images/* diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py new file mode 100755 index 000000000000..3a5b5141f99e --- /dev/null +++ b/contrib/clojure-package/examples/captcha/gen_captcha.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +from captcha.image import ImageCaptcha +import os +import random + +length = 4 +width = 160 +height = 60 +IMAGE_DIR = "images" + + +def random_text(): + return ''.join(str(random.randint(0, 9)) + for _ in range(length)) + + +if __name__ == '__main__': + image = ImageCaptcha(width=width, height=height) + captcha_text = random_text() + if not os.path.exists(IMAGE_DIR): + os.makedirs(IMAGE_DIR) + image.write(captcha_text, os.path.join(IMAGE_DIR, captcha_text + ".png")) From 6f0d22dc9db4e75bb7e456bef1259c6a09a00c08 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Tue, 1 Jan 2019 11:03:29 -0800 Subject: [PATCH 5/8] Make simple test work --- .../captcha/src/captcha/train_ocr.clj | 23 ++++++++++++++----- .../{example_test.clj => train_ocr_test.clj} | 4 ++-- 2 files changed, 19 insertions(+), 8 deletions(-) rename contrib/clojure-package/examples/captcha/test/captcha/{example_test.clj => train_ocr_test.clj} (58%) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj index aa296d721ef7..accc98904a51 100644 --- a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj +++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj @@ -1,5 +1,6 @@ (ns captcha.train-ocr (:require [clojure.java.io :as io] + [clojure.java.shell :refer [sh]] [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] [org.apache.clojure-mxnet.eval-metric :as eval-metric] @@ -15,12 +16,22 @@ (def data-shape [3 30 80]) (def label-width 4) +(when-not (.exists (io/file "captcha_example/captcha_train.lst")) + (sh "./get_data.sh")) + (defonce train-data (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec" :path-imglist "captcha_example/captcha_train.lst" :batch-size batch-size :label-width label-width :data-shape data-shape + ;:rotate 5 + ;:brightness 0.05 + ;:contrast 0.05 + ;:saturation 0.05 + ;:random-h 20 + ;:random-s 5 + ;:random-l 5 :shuffle true :seed 42})) @@ -71,12 +82,12 @@ relu4 (sym/activation {:data pool4 :act-type "relu"}) flattened (sym/flatten {:data relu4}) - dropped (sym/dropout {:data flattened :p 0.25}) - fc1 (sym/fully-connected {:data dropped :num-hidden 256}) - fc21 (sym/fully-connected {:data fc1 :num-hidden 10}) - fc22 (sym/fully-connected {:data fc1 :num-hidden 10}) - fc23 (sym/fully-connected {:data fc1 :num-hidden 10}) - fc24 (sym/fully-connected {:data fc1 :num-hidden 10})] + fc1 (sym/fully-connected {:data flattened :num-hidden 256}) + dropped (sym/dropout {:data fc1 :p 0.1}) + fc21 (sym/fully-connected {:data dropped :num-hidden 10}) + fc22 (sym/fully-connected {:data dropped :num-hidden 10}) + fc23 (sym/fully-connected {:data dropped :num-hidden 10}) + fc24 (sym/fully-connected {:data dropped :num-hidden 10})] (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0}))) (defn get-label-symbol diff --git a/contrib/clojure-package/examples/captcha/test/captcha/example_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj similarity index 58% rename from contrib/clojure-package/examples/captcha/test/captcha/example_test.clj rename to contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj index 0494d7d5689a..be839863db3c 100644 --- a/contrib/clojure-package/examples/captcha/test/captcha/example_test.clj +++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj @@ -1,6 +1,6 @@ -(ns captcha.example-test +(ns captcha.train-ocr-test (:require [clojure.test :refer :all] - [captcha.example :refer :all])) + [captcha.train-ocr :refer :all])) (deftest a-test (testing "FIXME, I fail." From 6dea7130465de099ea613dd9c564bfb274bcbe27 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Thu, 3 Jan 2019 21:58:36 -0800 Subject: [PATCH 6/8] Add test and update README --- .../examples/captcha/README.md | 61 ++++++++- .../examples/captcha/gen_captcha.py | 19 ++- .../captcha/src/captcha/infer_ocr.clj | 59 +++++---- .../captcha/src/captcha/train_ocr.clj | 84 ++++++------ .../captcha/test/captcha/train_ocr_test.clj | 120 +++++++++++++++++- 5 files changed, 265 insertions(+), 78 deletions(-) diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md index cc97442f6207..76a9438df52f 100644 --- a/contrib/clojure-package/examples/captcha/README.md +++ b/contrib/clojure-package/examples/captcha/README.md @@ -1,5 +1,62 @@ -This is the R version of [captcha recognition](http://blog.xlvector.net/2016-05/mxnet-ocr-cnn/) example by xlvector and it can be used as an example of multi-label training. For a captcha below, we consider it as an image with 4 labels and train a CNN over the data set. +# Captcha + +This is the clojure version of [captcha recognition](https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr) +example by xlvector and mirrors the R captcha example. It can be used as an +example of multi-label training. For a captcha below, we consider it as an +image with 4 labels and train a CNN over the data set. ![](captcha_example.png) -You can download the images and `.rec` files from [here](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip). Since each image has 4 labels, please remember to use `label_width=4` when generating the `.rec` files. +## Installation + +Before you run this example, make sure that you have the clojure package +installed. In the main clojure package directory, do `lein install`. +Then you can run `lein install` in this directory. + +## Usage + +### Training + +First the OCR model needs to be trained based on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip). +The training can be started using: +``` +$ lein train [:cpu|:gpu] [num-devices] +``` +which downloads the training/evaluation data using the `get_data.sh` script +before starting training. + +I encountered some out-of-memory issues while training using :gpu on Ubuntu +linux (18.04). However, the command `lein train` (training on one CPU) worked +well for me. + +The training runs for 10 iterations by default and saves the model with the +prefix `ocr-`. I was able to achieve an exact match accuracy of ~0.954 and +~0.628 on training and validation data respectively. + +### Inference + +Once the model has been saved, it can be used for prediction. This can be done +by running: +``` +$ lein infer +INFO MXNetJVM: Try loading mxnet-scala from native path. +INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-gpu from native path. +INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-cpu from native path. +WARN MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path]. +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +INFO org.apache.mxnet.infer.Predictor: Latency increased due to batchSize mismatch 8 vs 1 +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis +CAPTCHA output: 6643 +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/libmxnet.so +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala +INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865 +``` +The model is run on `captcha_example.png` by default. + +It can be run on other generated captcha images as well. The script +`gen_captcha.py` generates random captcha images for length 4. +Before running the python script, you will need to install the [captcha](https://pypi.org/project/captcha/) +library using `pip3 install --user captcha`. The captcha images are generated +in the `images/` folder and we can run the prediction using +`lein infer images/7534.png`. diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py index 3a5b5141f99e..43e0d26fb961 100755 --- a/contrib/clojure-package/examples/captcha/gen_captcha.py +++ b/contrib/clojure-package/examples/captcha/gen_captcha.py @@ -1,4 +1,21 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. from captcha.image import ImageCaptcha import os diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj index 2e1db681a46d..f6a648e9867b 100644 --- a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj +++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj @@ -1,27 +1,38 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + (ns captcha.infer-ocr - (:require [org.apache.clojure-mxnet.dtype :as dtype] + (:require [captcha.consts :refer :all] + [org.apache.clojure-mxnet.dtype :as dtype] [org.apache.clojure-mxnet.infer :as infer] [org.apache.clojure-mxnet.layout :as layout] [org.apache.clojure-mxnet.ndarray :as ndarray])) -(def batch-size 8) -(def channels 3) -(def height 30) -(def width 80) -(def label-width 4) -(def model-prefix "ocr") - (defn create-predictor [] - (let [data-desc {:name "data" - :shape [batch-size channels height width] - :layout layout/NCHW + (let [data-desc {:name "data" + :shape [batch-size channels height width] + :layout layout/NCHW :dtype dtype/FLOAT32} - label-desc {:name "label" - :shape [batch-size label-width] - :layout layout/NT - :dtype dtype/FLOAT32} - factory (infer/model-factory model-prefix + label-desc {:name "label" + :shape [batch-size label-width] + :layout layout/NT + :dtype dtype/FLOAT32} + factory (infer/model-factory model-prefix [data-desc label-desc])] (infer/create-predictor factory))) @@ -30,16 +41,16 @@ (let [[filename] args image-fname (or filename "captcha_example.png") image-ndarray (-> image-fname - infer/load-image-from-file - (infer/reshape-image width height) - (infer/buffered-image-to-pixels [channels height width]) + infer/load-image-from-file + (infer/reshape-image width height) + (infer/buffered-image-to-pixels [channels height width]) (ndarray/expand-dims 0)) label-ndarray (ndarray/zeros [1 label-width]) predictor (create-predictor) - predictions (-> (infer/predict-with-ndarray - predictor - [image-ndarray label-ndarray]) - first - (ndarray/argmax 1) + predictions (-> (infer/predict-with-ndarray + predictor + [image-ndarray label-ndarray]) + first + (ndarray/argmax 1) ndarray/->vec)] (println "CAPTCHA output:" (apply str (mapv int predictions))))) diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj index accc98904a51..91ec2fff3af7 100644 --- a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj +++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj @@ -1,5 +1,23 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + (ns captcha.train-ocr - (:require [clojure.java.io :as io] + (:require [captcha.consts :refer :all] + [clojure.java.io :as io] [clojure.java.shell :refer [sh]] [org.apache.clojure-mxnet.callback :as callback] [org.apache.clojure-mxnet.context :as context] @@ -12,10 +30,6 @@ [org.apache.clojure-mxnet.symbol :as sym]) (:gen-class)) -(def batch-size 8) -(def data-shape [3 30 80]) -(def label-width 4) - (when-not (.exists (io/file "captcha_example/captcha_train.lst")) (sh "./get_data.sh")) @@ -25,13 +39,6 @@ :batch-size batch-size :label-width label-width :data-shape data-shape - ;:rotate 5 - ;:brightness 0.05 - ;:contrast 0.05 - ;:saturation 0.05 - ;:random-h 20 - ;:random-s 5 - ;:random-l 5 :shuffle true :seed 42})) @@ -42,10 +49,15 @@ :label-width label-width :data-shape data-shape})) -(defn multi-label-accuracy - [label pred] +(defn accuracy + [label pred & {:keys [by-character] + :or {by-character false} :as opts}] (let [[nr nc] (ndarray/shape-vec label) - label-t (-> label ndarray/transpose (ndarray/reshape [-1])) + pred-context (ndarray/context pred) + label-t (-> label + ndarray/transpose + (ndarray/reshape [-1]) + (ndarray/as-in-context pred-context)) pred-label (ndarray/argmax pred 1) matches (ndarray/equal label-t pred-label) [digit-matches] (-> matches @@ -57,12 +69,14 @@ (ndarray/equal label-width) ndarray/sum ndarray/->vec)] - ; (float (/ digit-matches nr nc)) - (float (/ complete-matches nr)))) + (if by-character + (float (/ digit-matches nr nc)) + (float (/ complete-matches nr))))) (defn get-data-symbol [] (let [data (sym/variable "data") + ;; normalize the input pixels scaled (sym/div (sym/- data 127) 128) conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32}) @@ -83,11 +97,10 @@ flattened (sym/flatten {:data relu4}) fc1 (sym/fully-connected {:data flattened :num-hidden 256}) - dropped (sym/dropout {:data fc1 :p 0.1}) - fc21 (sym/fully-connected {:data dropped :num-hidden 10}) - fc22 (sym/fully-connected {:data dropped :num-hidden 10}) - fc23 (sym/fully-connected {:data dropped :num-hidden 10}) - fc24 (sym/fully-connected {:data dropped :num-hidden 10})] + fc21 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc22 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc23 (sym/fully-connected {:data fc1 :num-hidden num-labels}) + fc24 (sym/fully-connected {:data fc1 :num-hidden num-labels})] (sym/concat "concat" nil [fc21 fc22 fc23 fc24] {:dim 0}))) (defn get-label-symbol @@ -102,29 +115,6 @@ labels (get-label-symbol)] (sym/softmax-output {:data scores :label labels}))) -(comment - (def batch (mx-io/next train-data)) - (mx-io/batch-index batch) - (-> batch mx-io/batch-label first ndarray/->vec) - (-> batch mx-io/batch-data first ndarray/->vec) - (def model (m/module (create-captcha-net) - {:data-names ["data"] :label-names ["label"]})) - (m/bind model {:data-shapes (mx-io/provide-data-desc train-data) - :label-shapes (mx-io/provide-label-desc train-data)}) - (m/init-params model - {:initializer (initializer/uniform 0.01) - :force-init true}) - ; (m/init-params model) - (m/forward model batch) - (m/output-shapes model) - (m/outputs model) - (-> batch mx-io/batch-label first ndarray/->vec) - ; (-> model m/outputs first first (ndarray/> 0.5) (ndarray/reshape [0 -1 4]) (ndarray/argmax 1) ndarray/->vec) - (-> model m/outputs-merged first ndarray/->vec) - (m/backward model) - (def out-grads (m/grad-arrays model)) - (map #(-> % first ndarray/shape-vec) out-grads)) - (def optimizer (optimizer/adam {:learning-rate 0.0002 @@ -150,7 +140,7 @@ :magnitude 2.34}) :optimizer optimizer :eval-metric (eval-metric/custom-metric - #(multi-label-accuracy %1 %2) + #(accuracy %1 %2) "accuracy")})}) (println "Finished the fit") model)) @@ -163,4 +153,4 @@ (mapv #(context/gpu %) (range num-devices)) (mapv #(context/cpu %) (range num-devices))) model (train-ocr devs)] - (m/save-checkpoint model {:prefix "ocr" :epoch 0}))) + (m/save-checkpoint model {:prefix model-prefix :epoch 0}))) diff --git a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj index be839863db3c..ab785f7fedf2 100644 --- a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj +++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj @@ -1,7 +1,119 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + (ns captcha.train-ocr-test (:require [clojure.test :refer :all] - [captcha.train-ocr :refer :all])) + [captcha.consts :refer :all] + [captcha.train-ocr :refer :all] + [org.apache.clojure-mxnet.io :as mx-io] + [org.apache.clojure-mxnet.module :as m] + [org.apache.clojure-mxnet.ndarray :as ndarray] + [org.apache.clojure-mxnet.shape :as shape] + [org.apache.clojure-mxnet.util :as util])) + +(deftest test-consts + (is (= 8 batch-size)) + (is (= [3 30 80] data-shape)) + (is (= 4 label-width)) + (is (= 10 num-labels))) + +(deftest test-labeled-data + (let [train-batch (mx-io/next train-data) + eval-batch (mx-io/next eval-data) + allowed-labels (into #{} (map float (range 10)))] + (is (= 8 (-> train-batch mx-io/batch-index count))) + (is (= 8 (-> eval-batch mx-io/batch-index count))) + (is (= [8 3 30 80] (-> train-batch + mx-io/batch-data + first + ndarray/shape-vec))) + (is (= [8 3 30 80] (-> eval-batch + mx-io/batch-data + first + ndarray/shape-vec))) + (is (every? #(<= 0 % 255) (-> train-batch + mx-io/batch-data + first + ndarray/->vec))) + (is (every? #(<= 0 % 255) (-> eval-batch + mx-io/batch-data + first + ndarray/->vec))) + (is (= [8 4] (-> train-batch + mx-io/batch-label + first + ndarray/shape-vec))) + (is (= [8 4] (-> eval-batch + mx-io/batch-label + first + ndarray/shape-vec))) + (is (every? allowed-labels (-> train-batch + mx-io/batch-label + first + ndarray/->vec))) + (is (every? allowed-labels (-> eval-batch + mx-io/batch-label + first + ndarray/->vec))))) + +(deftest test-model + (let [batch (mx-io/next train-data) + model (m/module (create-captcha-net) + {:data-names ["data"] :label-names ["label"]}) + _ (m/bind model + {:data-shapes (mx-io/provide-data-desc train-data) + :label-shapes (mx-io/provide-label-desc train-data)}) + _ (m/init-params model) + _ (m/forward-backward model batch) + output-shapes (-> model + m/output-shapes + util/coerce-return-recursive) + outputs (-> model + m/outputs-merged + first) + grads (->> model m/grad-arrays (map first))] + (is (= [["softmaxoutput0_output" (shape/->shape [8 10])]] + output-shapes)) + (is (= [32 10] (-> outputs ndarray/shape-vec))) + (is (every? #(<= 0.0 % 1.0) (-> outputs ndarray/->vec))) + (is (= [[32 3 5 5] [32] ; convolution1 weights+bias + [32 32 5 5] [32] ; convolution2 weights+bias + [32 32 3 3] [32] ; convolution3 weights+bias + [32 32 3 3] [32] ; convolution4 weights+bias + [256 28672] [256] ; fully-connected1 weights+bias + [10 256] [10] ; 1st label scores + [10 256] [10] ; 2nd label scores + [10 256] [10] ; 3rd label scores + [10 256] [10]] ; 4th label scores + (map ndarray/shape-vec grads))))) -(deftest a-test - (testing "FIXME, I fail." - (is (= 0 1)))) +(deftest test-accuracy + (let [labels (ndarray/array [1 2 3 4, + 5 6 7 8] + [2 4]) + pred-labels (ndarray/array [1 0, + 2 6, + 3 0, + 4 8] + [8]) + preds (ndarray/one-hot pred-labels 10)] + (is (float? (accuracy labels preds))) + (is (float? (accuracy labels preds :by-character false))) + (is (float? (accuracy labels preds :by-character true))) + (is (= 0.5 (accuracy labels preds))) + (is (= 0.5 (accuracy labels preds :by-character false))) + (is (= 0.75 (accuracy labels preds :by-character true))))) From 95f1f20741cf1c7b257593c74afcba40bc1c0af8 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Sat, 5 Jan 2019 16:07:49 -0800 Subject: [PATCH 7/8] Add missing consts file --- .../examples/captcha/src/captcha/consts.clj | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/consts.clj diff --git a/contrib/clojure-package/examples/captcha/src/captcha/consts.clj b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj new file mode 100644 index 000000000000..318e0d806873 --- /dev/null +++ b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj @@ -0,0 +1,27 @@ +;; +;; Licensed to the Apache Software Foundation (ASF) under one or more +;; contributor license agreements. See the NOTICE file distributed with +;; this work for additional information regarding copyright ownership. +;; The ASF licenses this file to You under the Apache License, Version 2.0 +;; (the "License"); you may not use this file except in compliance with +;; the License. You may obtain a copy of the License at +;; +;; http://www.apache.org/licenses/LICENSE-2.0 +;; +;; Unless required by applicable law or agreed to in writing, software +;; distributed under the License is distributed on an "AS IS" BASIS, +;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +;; See the License for the specific language governing permissions and +;; limitations under the License. +;; + +(ns captcha.consts) + +(def batch-size 8) +(def channels 3) +(def height 30) +(def width 80) +(def data-shape [channels height width]) +(def num-labels 10) +(def label-width 4) +(def model-prefix "ocr") From 3dfaca23f6cff756d9700283300798b1143a5149 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Tue, 8 Jan 2019 19:15:23 -0800 Subject: [PATCH 8/8] Follow comments --- .../clojure-package/examples/captcha/README.md | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md index 76a9438df52f..6b593b2f1c65 100644 --- a/contrib/clojure-package/examples/captcha/README.md +++ b/contrib/clojure-package/examples/captcha/README.md @@ -2,10 +2,10 @@ This is the clojure version of [captcha recognition](https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr) example by xlvector and mirrors the R captcha example. It can be used as an -example of multi-label training. For a captcha below, we consider it as an +example of multi-label training. For the following captcha example, we consider it as an image with 4 labels and train a CNN over the data set. -![](captcha_example.png) +![captcha example](captcha_example.png) ## Installation @@ -18,19 +18,18 @@ Then you can run `lein install` in this directory. ### Training First the OCR model needs to be trained based on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip). -The training can be started using: +The training can be started using the following: ``` $ lein train [:cpu|:gpu] [num-devices] ``` -which downloads the training/evaluation data using the `get_data.sh` script +This downloads the training/evaluation data using the `get_data.sh` script before starting training. -I encountered some out-of-memory issues while training using :gpu on Ubuntu -linux (18.04). However, the command `lein train` (training on one CPU) worked -well for me. +It is possible that you will encounter some out-of-memory issues while training using :gpu on Ubuntu +linux (18.04). However, the command `lein train` (training on one CPU) may resolve the issue. The training runs for 10 iterations by default and saves the model with the -prefix `ocr-`. I was able to achieve an exact match accuracy of ~0.954 and +prefix `ocr-`. The model achieved an exact match accuracy of ~0.954 and ~0.628 on training and validation data respectively. ### Inference @@ -52,7 +51,7 @@ INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet604530827929 INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865 ``` -The model is run on `captcha_example.png` by default. +The model runs on `captcha_example.png` by default. It can be run on other generated captcha images as well. The script `gen_captcha.py` generates random captcha images for length 4.