From f96c43e4a3cad2bec26806f115e49c87b24c5df9 Mon Sep 17 00:00:00 2001 From: Kedar Bellare Date: Wed, 9 Jan 2019 16:53:24 -0800 Subject: [PATCH] Clojure example for fixed label-width captcha recognition (#13769) * Clojure example for fixed label-width captcha recognition * Update evaluation * Better training and inference (w/ cleanup) * Captcha generation for testing * Make simple test work * Add test and update README * Add missing consts file * Follow comments --- .../examples/captcha/.gitignore | 3 + .../examples/captcha/README.md | 61 +++++++ .../examples/captcha/captcha_example.png | Bin 0 -> 9762 bytes .../examples/captcha/gen_captcha.py | 40 +++++ .../examples/captcha/get_data.sh | 32 ++++ .../examples/captcha/project.clj | 28 ++++ .../examples/captcha/src/captcha/consts.clj | 27 +++ .../captcha/src/captcha/infer_ocr.clj | 56 +++++++ .../captcha/src/captcha/train_ocr.clj | 156 ++++++++++++++++++ .../captcha/test/captcha/train_ocr_test.clj | 119 +++++++++++++ 10 files changed, 522 insertions(+) create mode 100644 contrib/clojure-package/examples/captcha/.gitignore create mode 100644 contrib/clojure-package/examples/captcha/README.md create mode 100644 contrib/clojure-package/examples/captcha/captcha_example.png create mode 100755 contrib/clojure-package/examples/captcha/gen_captcha.py create mode 100755 contrib/clojure-package/examples/captcha/get_data.sh create mode 100644 contrib/clojure-package/examples/captcha/project.clj create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/consts.clj create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj create mode 100644 contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj create mode 100644 contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj diff --git a/contrib/clojure-package/examples/captcha/.gitignore b/contrib/clojure-package/examples/captcha/.gitignore new file mode 100644 index 000000000000..e1569bd89020 --- /dev/null +++ 
b/contrib/clojure-package/examples/captcha/.gitignore
@@ -0,0 +1,3 @@
+/.lein-*
+/.nrepl-port
+images/*
diff --git a/contrib/clojure-package/examples/captcha/README.md b/contrib/clojure-package/examples/captcha/README.md
new file mode 100644
index 000000000000..6b593b2f1c65
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/README.md
@@ -0,0 +1,61 @@
+# Captcha
+
+This is the Clojure version of the [captcha recognition](https://github.com/xlvector/learning-dl/tree/master/mxnet/ocr)
+example by xlvector and mirrors the R captcha example. It can be used as an
+example of multi-label training. For the following captcha example, we treat it as an
+image with 4 labels and train a CNN over the data set.
+
+![captcha example](captcha_example.png)
+
+## Installation
+
+Before you run this example, make sure that you have the Clojure package
+installed. In the main Clojure package directory, do `lein install`.
+Then you can run `lein install` in this directory.
+
+## Usage
+
+### Training
+
+First, the OCR model needs to be trained on [labeled data](https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip).
+The training can be started using the following:
+```
+$ lein train [:cpu|:gpu] [num-devices]
+```
+This downloads the training/evaluation data using the `get_data.sh` script
+before starting training.
+
+It is possible that you will encounter some out-of-memory issues while training using :gpu on Ubuntu
+Linux (18.04). However, the command `lein train` (training on one CPU) may resolve the issue.
+
+The training runs for 10 epochs by default and saves the model with the
+prefix `ocr-`. The model achieved an exact-match accuracy of ~0.954 and
+~0.628 on training and validation data, respectively.
+
+### Inference
+
+Once the model has been saved, it can be used for prediction. This can be done
+by running:
+```
+$ lein infer
+INFO MXNetJVM: Try loading mxnet-scala from native path.
+INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-gpu from native path.
+INFO MXNetJVM: Try loading mxnet-scala-linux-x86_64-cpu from native path.
+WARN MXNetJVM: MXNet Scala native library not found in path. Copying native library from the archive. Consider installing the library somewhere in the path (for Windows: PATH, for Linux: LD_LIBRARY_PATH), or specifying by Java cmd option -Djava.library.path=[lib path].
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+INFO org.apache.mxnet.infer.Predictor: Latency increased due to batchSize mismatch 8 vs 1
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+WARN org.apache.mxnet.DataDesc: Found Undefined Layout, will use default index 0 for batch axis
+CAPTCHA output: 6643
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/libmxnet.so
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865/mxnet-scala
+INFO org.apache.mxnet.util.NativeLibraryLoader: Deleting /tmp/mxnet6045308279291774865
+```
+The model runs on `captcha_example.png` by default.
+
+It can be run on other generated captcha images as well. The script
+`gen_captcha.py` generates random captcha images of length 4.
+Before running the Python script, you will need to install the [captcha](https://pypi.org/project/captcha/)
+library using `pip3 install --user captcha`. The captcha images are generated
+in the `images/` folder and we can run the prediction using
+`lein infer images/7534.png`.
diff --git a/contrib/clojure-package/examples/captcha/captcha_example.png b/contrib/clojure-package/examples/captcha/captcha_example.png new file mode 100644 index 0000000000000000000000000000000000000000..09b84f7190fab4cb391f8a3927b10f44d609aaaf GIT binary patch literal 9762 zcmV+-Cf(VIP)002A)0ssI2ZV4tr001TFNkl$qd)Kg=JPmwuTSN6f=S^pBb7Zny2R zWJ{E6%91IIGHG!GH-H2PR$;5Ed-E>2oOvh`g0KQSx5g`fB>Z>tnC%C^w(T2xTUdOfJL=Hzse)#S6+cuN-_SaD$Q` zA;ZuM0bqH|Pj)(6f!O+^Ns+YV-1}EPY%k5B%eph^&%>uOK4VP(7tT2VoNb-!|Myh$ z`NwHCQL>fJ{p!~H0eMN0Jmw&m5$EwAzdmzo@pjL~EMt|R{BIY29r|I7RgUjJb}PSq zs`)%XefdQ_sp)JH+aXxV4<|pFOz%jMG0v;DmM96qqx706+5ctj`jDAaN_RUQKlF1` zDAWg|{YtA-%@-O${jtPX%a3NxwDW~PyNQx<7MJ|eBe%AUyeMYm>|JVEpSUj)mT7BR zy}25zOioVMYxT3Ob3YqU#=(%y7D%LG5V>u@t4QfaQdLZd6hHWuam9cJ5E1i=jC9r?K{(1A6`V%3T0XBo?Z98iNYZw_IAmok{vkWcTr2SI;xkM%XX+};6 zQ1A2o9k2^qA-I>tiIOeZK7Hqmq>MWv5P_?mY8tudZnwXco6xGIofOIE7r(f1@8%M8 zzFN89SLrSi1Jwu9uJOIs6C&XLpBoX7W)_Qm8%x9D_vOcf?P3xMEKs)`=UuJms$P7T%uhhQl3 zxA%_XLiW;L%e0M;hzJC6p2T_bM6|RaW3<*hMd=uu<}&aGT?Q96W^JI7?c+_Ea)GM~ldjl!7U9Ct7}f>#r$gOfVt}dC>3&e>nPsA#WgbLxya-{jJ@Elv0V7Grid*?8G3^ zX8_0mz!Cre-v6h-LV?O`-n$;}9US}7!MC^GsR3iaXt3NOAp$~N#7U7n5iK2f{@$U{ zvEk8jwcP9VF45JcqQf)diGMu7aUQqgcAG7v zEC+xR4+q`qKOOp6Giq#Ide&m)CKP8$H||}YzH(k&PACV#34jb)WxxUw5P-E71#r(H z4ia)x%sm@E(~@(DxRw1Z$r4$}Bu&6tTUc}*1`Irg9Pclm$3x%+#}6D!lB80p=m9g< z9LrUn7x~20)ob}=SHw9gX>$-v40=m>oMcyF0Y{C6YSLppZ%u0};As zTjwT|J3rk0y+&An=wOmeelhW7O1aJyN{qOt2f&G3nWxE<*oIx#<(&7a3OFx^B>>2a zyj~kRQJiSY`7@J$>F{1^v)n1HASLufenVVl?>zUmBb-OD8MlQETQ7=Yp|{x6aW9RT z0zV5*jvgQM2A&kVN-3GknWfn&dAp4Z0vrG&Ly_JV-1h!$$SAF8S zp67Z4wl)wB40;1!-8}zQ{CCC4I$#{?X7OCUIMupci7Njvf|ESCH8a^=?1D#7K#9UK ziz4TV$!=cc*QRfLHhDG=3j}T%fgw`X%iK5jzY&o?GB{F=s{M_s6L*|FVo7yAxq2qc zBs84FrtN(3W@i!rA`zCn4F^UH3eQjmXb}urG6bA}dMfVbNt`9aPr_MR7mTq=P!_=V zSH~86i`R?E%+bh^mD!T*NtI@K=KKEB>K*_#B3Nd!-JAdV-lddhE;u3gcb9gJ?H+xC z%PrCJ=KQUeH=jbr1S3Nx*S>P=;o$z-ShLyy07p3eheIzswC7-})4t+eP17{V3JOqC z`r-7MJ?)$C?SFG~f|oh7N-%;6+R z6>6q4b66jG+P$xD24{Kp#l)9gkr*C 
z%}lY7d0M=B;8ZoLI>K3I7XTc`aU7>ss{ZigTf0g-!`QPu9j8evX}6O0=7D8o$T;)3 z8?x|F?Z7MJF9vP^0Nmrqo;thz!BmY6km}oyl<91#=w&8*dj0m}!Rr&8=9xA+ zqXGmg`{l!9hZ>dolQsD+EL}9r7;Y_qAnqotDTek`Muy9cjojGiU@a0MA);`E5W;bsb)S4R#Cb9r?d}zc1polo@&n%3 z3-xE;IrLV^+sxz%{jgE3|MaCF{p{Hv4T%BYdZ1|&HCw3S{Edr6QABk9)$k zxdlSx%nAa8M9Kco6P(~$1j0YH_jzBrXbBjQz=YB-C%$}21jB=f;rav{v>oi)KBrG% ziIcTo-9B&oxe9{FjfN}D@y7lg)#38PY}&{Vo3+NCfsq#~&(w=r#G{hCp*I}oi32Ro z6M@djBPT1RvL{@o0kj3gl(LqbySX&^o5>H_>D*K0*mr!L>xf582xN#Hkt6VJM~b$N zek6GPJ=HQ}1F)GzU|j!@DP+9Nim0Rgtf)WKIV}%e0#tw`P zjBQ;{L{8WY8~bWwFYJFV+SpUFU}w~MHyJTJPTt<7upF$QY4wkN_iCd(bG=9=c>3kdnMJ9lts$`u&lz+ z4>u=xqtxnF(0_Oxf|XXrYGX7YlCdf+))L&yllk8KpDuoWE54PnEJ>wj+}n$HmU^A< zz4*o^1Y2u+DoIS5P>u|3VN|C8&(6=!54crFIIimk-ui#OaxDuXs-m*_43kWLaqXPU zB{*mj-4AXP01=S_AYiU=bEh#-|M}>9Pc{8ss3J~dD-FvKLm&b&CU*;CjJ1}r^%eEn z)aoNh^ZF9J*j+--0D-JU!PXKS7s+q0pPpX0-Gwd%Kt8D=)v3OiTpgPqW5^;e^u2Xl z_@qb{pX!qh%&HAsD`$T1~W+ zN^V!0*J-*aiYSPRq5uE{Mr+D85|y-Mdp2uzBPnXG@P!k(PQ@*1HZ7^&Uq1cim2)yr z*S2h}?G?!vv*!|%A|hE*z1*`-Ev)|e?%B^5zC7Q#^skqHbA9pFY_ApP@roYrujN)Y zw>;;Ef`F7T#DO70iwqb5l)UJTgRhjG5)lCcAboZBD>8PAZLUS4y2hmkDPFoC}(-r&3Y|0pYj}C#vMRJhxlH9 zcQveRBEBnAs{ruv^qKxI+cFn6k4zx~F^BaTmYHmI+rOLmXpRks7Gs87w3;|M`b@N`F$G^J% zyMH?L)8+%!x7}K`yvQe}uIE@Vz<~MB?R_?I?r#|c1jg80gLRQ{K)#W?_)x?}(l6GO zHv84h_xE_a-rW6~b`#aku>$-bJM$b=PfsFc&TAvS@$xkl)>1KQ@3l;YS5P9U1w%oWjH&&zZh&Z6n7fttr z$&ZiDAN<~t*Om#ETFzw)bNNDQ@5kpDXN1TZGljO9u>!d(fT0Lo4Ypm({p#v(-`V|x zX4EKoQ6t!1ty-KViIUP}$gv7#(6l^~k{?BGRCdY*8Q@H~&dQkO^473<$mRaSMe^zG zzruAqS?vcO?4QK+K?#z66 z>+KHpG`OnQToBm4^nL+7>5S6h)y6$uj~7o#5EMqv5v`7zl`fNB{`Zq^)A^ zR(I-RaAS*J9wkP9`>grqv)~|Wxv2bKZwYi-3VT{#l z_3t14UVPz$dwPM4Mb3aBJYblJv@%2%fdN1yfPPODkzUwb5E6;H+-npD+ zSz0LT0JxZFdA@ULpWGL@;TCmToX6K@uE#22jD;eI+#vY2Q(p4IDvu%_Bql?SnaDqy z{ggW>b&3wpZIP)=X7ZtN#K;Mo!O(v_{PR1VnUAK=WO=4hBS+u>0E{!n84#JmWJ@_{ zs8!1^H=ZAk8e7r-+r_#y=Nl626Fcqigb}3>5cx zfx1K8OVbmzpgOn(f@LP#owv-Oi#}&_U0w#ZpCAWT3Nx+F+6fZD1ZV2jpM<-gFE&= zy^z^jXL7?m7%Tx15g4$AE-u}=EGcujvojn%)EIA+>TB)$&zC2@0I+B1AFJ$xAL#@ySANKWjV5gPj!w=^ 
zn!*?jh8r#zN%B4Kh%>%#@7``N-Z9u5EY(V0>FL#Bt#zf=cfvw16rSTVpDdApdvw66CS`a$G{p$O1`Eli$fd79-#u9Q-R)?kq#X^Jp-dFZ+RX^wRf z_(8o=fA5)hMsdWk4geISG#Ld&|AD?P+9=1Ddd>kXC%-aPcZXDi2o*dBkdxzbTypWw(3e=z<=iARp*hFM*rJNZnP$9~{NLAc4uk-!UF$8oJkhL&23 zS8iR6vxFhz$T{Oc0GfiH8)3MSn(F7?kz=#m*i{?xnNQYoh9VNk7;GVIiO766;DI0j zpL@G%JJ(W`s6>*KB$rF~$&%GJU&;aujyeYq?K~K6Q;|k~X?)j#llx9Y7#5w(0-*zQ ze6f>uGM!qo+j5E?4UrS}f0jR4pWu>T8t`jJ!UMuGL{O-Ekt`eq0#+5stSW~H{ouLr zlfLhfwh74_y~#L_3sW$c<+dOqofM#8ZG3yVwPn#X%Y(2)b3Xvbad*aPgLobwB zi4!YObxm9us6D;XoOSV#nk)ct9hVV0nZqpO4Bf~B&ZJ3CPoL@2xsj^mzy)xD0syL5 z0lG6{PB|jRjG*65j+X+o>7&2zL{o?w$ zZuYe5>(-)A&FPslu!#h_t~%?lY8 zz$g^Kf#LCGqoyyrH+xepIiF|=09IS8jVd${0Yhs{$bvVAUtg&h^aE`D{W^{_5LE}f zYDhj>N}-sH?|gjWPjUXl?v2<8{W;$1P4%hSxMT;w_xyTgs4`M5^*TM0000Cb4^B0n zpUdYjE?fbCz2!Y;Tj$;!c_Uo0u`^<+%@CL+BQ#q|md;K7T{4LrK|_$bz4@-!Scl+M z@$@81WG<6D?c2DlFUt>%qKyucCz(u(G}9>)qU4rF>N`U(Oq5KtY^QU7X?@vB`|KhS zIhyIj2v-;m$bc+~lm|RmsrCz{?zZk$E7c$fwtg--@Pgy}kKOqDREHhauWK>7op^ajj(>aex%>l+19ze<G&!_TjtjZLqg6lcix z500%ou)0d0_l3Bpdl?l53^UG)A+?#J*ynX4_XI|L`!F~|n02vTPUf4_HXhkW{N~azG`Z3rcHZX@vh9l3YTnPn= zzO3Ku-EL*`{qj(2tSE3OEgz~LaD|t4Enz!@92rYZ{`;>_mrq6r&B%`cAlJoWy10-o zCOT=rHZod-gb{-TkRx*hA~Jdc2PsyP3LEgcQcfNcQaROp{$AX2m6P&<2^K4PW9C+4 z-;nFN_cy`-5+zSLVZY1IXaCxY+nHaWLX#UO5I5#-TCj{G5Fmmh`o~bV;DIpUWey1- z_6tVv$lgQC0s2+d4t;ekGkIQqt-%ly1MFnob~fL}#naPg5|sb|AQF*FBFx;8VCQoO zPI{hS?=}AK)Q73bkh4Tf8>0_GbqpU$0Kw4Fq!iIRXwnp_lBEIrpe_Uzsh1EpH%hp~x6 z>&+8^02Hiela7j$JOO~BD6ULSEYM=E_50U9`oZw)M|K~qiP}SnQAKg}=GDTe{tG4s zxufi+X9(8i6247mAAM3ce|yV#$YP9~u*{=s5iQaZ39Hb?lwQ6wF&GSZYvM8&dTAo*D270cGkCz5dgq~DJ;>3A?Uuk zK0bW-FWoPC5EB>+RBH}hN+-HD1^^;(EO(6vrF(Mt=(GDyG|P>?B8l8EWMN2QYBJ;q z9EshhA}?}nRi<-kbCs&I_rBaSJW_+&<}a-%ipAbiUKE+l``!m3k|n8SE1g?Q;WE~8 z2`lLAV=JVt#6wQ_0{}7{5(BeIJ9iX$nC4ll)9Um(eT!qIvW@d(cCi&}i9UursPfuK zurto%N**ooCG7z9s0_igW5;XdZ>-66XBFphOU_B1-p@B8NS(H``INGY3s;T}9$p*D zey#CRYoRTW0|zoxxcYqOBJ(*|bVxW>_?dIS866xN7_KyG!8hdtC9m}Q{+Ivl>)%K> zMUK^Swd|Dob-P5Mj1r#1gwYPgES>GNxMG1HF5jZER`Lo}T%Nu>nckrg5bsx?8D*r+ 
zHZO==#U*T?&b8S10QiH{@$u@BV|$M-UFzIY(@NNa)Iza%r!#YRes=G!v0AC>dw!fH ziArX=v!C7js|pPo^3Zwlz{$W1CTAy!A#es-A&Dvu?C^$nRGQ1ZEbredG23fhnVE>0 z0OMhQ+7XKHK3Bhh8wg_kfq* zuM1eUMwfbvEj_p7cKZ$mSX*+!s2WWrcTVlxFog2>T&|1t30~=C0$hH5@YLbqgWZL= z$Ft02ZNBikn;#BAZF+I~mBXjXL3uG<`smspXP}k(d9IZJrvy>es{nvs4ZBMT5uiqf zhzwVnlFR$|D(FH-PPCx9{X1EFO`+!X!&7168gZ+OP zq95>@?Y7Qb`?KMOTriH+d6CJygi)h8G!WJrL4BFFG)v<=Y30^xV+FDP>eTL_uC5*P zcY_THUX?l~qNSqa-SL2kxq>(uV>(4VBk!e4#f9Wj9%{z9wG=341e+)KpD20J?BYF{ zr4l3r0Gxx5UFvDvFzlBXd6s3jJJWw^eO3sC2EkBZyj?-Ft>^2dApm%5*Bfg$VAOE= z#udQ@LDBbxNLyC$dR{qL-9K8{v$s4Nh9LkXN?u&J(%0}S(-m)W;0G97$NtA?i z?riJa9)DMrmvg0*uo^WF9d5xX+ECTjtL7?j(gwg8_yS_=8k==h$0= zQO$F`&|EKX>^shrB$u&GAYuSO1Ugkt4{A~6nbBj-q2@|saNP>>Rd?;I+@#_>{?U+N{+ksf+_CF>+c&Yqscdr_gEh3i*0SFF=o&Ju&VBm>d zeEJC3P=eRK#F7{N_{b0c_2O^l^g>E0bAlt$oYXUfuu_S3n)Z7s@2WkU*YoqH9 z3;--p28Gs|k%NFPi8A*GqRNqxLqh{YQS^=0-FgN2@+S56p6^$9`TNJ-P?z+Sxsy_n znjASMN|Lp|k3TNDg(->|eHSy9c&RA6=x_iaZ2}&Qly~+|&HO{KjpCrPUqZxp_rLjn zzWVo$D%=`lf&*trK!(VehzHGJXt>l^KBXV|Vc_~6IiwA`0ubXoiFtCb+u{xffM`WE ztIXu*SI*AiqK+&wgMg^9B!ZHNBL5+mw1$4@yIuf6rixIM2HeW=%Atch4h)2~wP#tk zs4Z6sy2st0oOt)f;;j!RPbVs7E=SH&lO@(_RF>(?DsquIoH-m2p|HO726*+*?(y(H z1;GGN@}l9WF%s@9x=IzsY9ei25WetSatHm|+k3vh>h+N!YlzJKq^m-6CBEKRs$ZGD z{PLlfBR^6~C7G1DoLRivmh+OO=n7;=1U_?{eCUNkC)YiR`t5;*uOE12>B9hy9(s<-C+qmfB1jiHwQR0dN3-EmsO_VmVfH_|C~U z8gG$hJ>IkSpY-8 zF*o^QrhjwccO#V@I|g@UCRbWto4GMLKh@2;OF_(pAa?&9g~oes&xxV(Q1DQK{nAh{ zu*(@PL%4frWTdjQ7S`52wyguy>poL*qxTNH)wlEiaPu^26YC5aLym|D7D!u56j;Bi zqmzweJIceAp#0Fjo*uSOuwQbR^Sz_5eR=b%13UKzUg&}d{2=n8wcBc5XKckA` z@Sa2K-VyLP*CBi}efRz~#&MDS)4`v^9w9@*@A!`9FVo?{CvEi;=irKVpe25s}c$n`9d_#f?c^`^*Y|Gm9i?g7YPSalYU^GE^EAHhq<^W()Z(2lZ+tZg(_t2% zk3r$p+{#f$y_L~Ntiay7)o^HRe4=x$Qk3E)X-dQc2L%uSS!^hD)HL8tAHtrrD zs|y#ugKs`7iXzUFd%f0gu70o}Iw@p8E6S@0?pHM`!Z}lg^1-3;AHMKry)yJA?1A8zY{aSf5s{jB107*qoM6N<$f)Wn*ga7~l literal 0 HcmV?d00001 diff --git a/contrib/clojure-package/examples/captcha/gen_captcha.py b/contrib/clojure-package/examples/captcha/gen_captcha.py new file mode 100755 index 
000000000000..43e0d26fb961
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/gen_captcha.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from captcha.image import ImageCaptcha
+import os
+import random
+
+length = 4  # digits per captcha; default for random_text()
+width = 160  # width passed to ImageCaptcha
+height = 60  # height passed to ImageCaptcha
+IMAGE_DIR = "images"  # output directory, relative to the working directory
+
+
+def random_text(num_digits=None):
+    """Return a string of `num_digits` random digits (default: `length`)."""
+    count = length if num_digits is None else num_digits
+    return ''.join(str(random.randint(0, 9)) for _ in range(count))
+
+
+if __name__ == '__main__':
+    image = ImageCaptcha(width=width, height=height)
+    captcha_text = random_text()
+    os.makedirs(IMAGE_DIR, exist_ok=True)  # no exists()/makedirs() race
+    image.write(captcha_text, os.path.join(IMAGE_DIR, captcha_text + ".png"))
diff --git a/contrib/clojure-package/examples/captcha/get_data.sh b/contrib/clojure-package/examples/captcha/get_data.sh
new file mode 100755
index 000000000000..baa7f9eb818f
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/get_data.sh
@@ -0,0 +1,32 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -evx
+
+# Script directory; quoted for paths with spaces, `&&` so a failed cd aborts.
+EXAMPLE_ROOT=$(cd "$(dirname "$0")" && pwd)
+data_path=$EXAMPLE_ROOT
+
+if [ ! -f "$data_path/captcha_example.zip" ]; then
+  wget https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/R/data/captcha_example.zip -P "$data_path"
+fi
+
+if [ ! -f "$data_path/captcha_example/captcha_train.rec" ]; then
+  unzip "$data_path/captcha_example.zip" -d "$data_path"
+fi
diff --git a/contrib/clojure-package/examples/captcha/project.clj b/contrib/clojure-package/examples/captcha/project.clj
new file mode 100644
index 000000000000..fa37fecbe035
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/project.clj
@@ -0,0 +1,28 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;;     http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(defproject captcha "0.1.0-SNAPSHOT"  ; Leiningen build for the captcha OCR example
+  :description "Captcha recognition via multi-label classification"
+  :plugins [[lein-cljfmt "0.5.7"]]
+  :dependencies [[org.clojure/clojure "1.9.0"]
+                 [org.apache.mxnet.contrib.clojure/clojure-mxnet "1.5.0-SNAPSHOT"]]
+  :main ^:skip-aot captcha.train-ocr            ; entry point for a plain `lein run`
+  :profiles {:train {:main captcha.train-ocr}   ; selected by the "train" alias
+             :infer {:main captcha.infer-ocr}   ; selected by the "infer" alias
+             :uberjar {:aot :all}}
+  :aliases {"train" ["with-profile" "train" "run"]    ; `lein train [args]`
+            "infer" ["with-profile" "infer" "run"]})  ; `lein infer [args]`
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/consts.clj b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
new file mode 100644
index 000000000000..318e0d806873
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/consts.clj
@@ -0,0 +1,27 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;;     http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.consts)  ; constants shared by the training, inference and test namespaces
+
+(def batch-size 8)                        ; examples per batch (record iterators, predictor inputs)
+(def channels 3)                          ; image channels
+(def height 30)                           ; image height fed to the network
+(def width 80)                            ; image width fed to the network
+(def data-shape [channels height width])  ; per-image shape, channel first
+(def num-labels 10)                       ; output size of each per-character classifier head
+(def label-width 4)                       ; labels (characters) per captcha
+(def model-prefix "ocr")                  ; prefix used when saving/loading the model checkpoint
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
new file mode 100644
index 000000000000..f6a648e9867b
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/infer_ocr.clj
@@ -0,0 +1,56 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;;     http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.infer-ocr
+  (:require [captcha.consts :refer :all]
+            [org.apache.clojure-mxnet.dtype :as dtype]
+            [org.apache.clojure-mxnet.infer :as infer]
+            [org.apache.clojure-mxnet.layout :as layout]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]))
+
+(defn create-predictor
+  "Returns an inference predictor for the model saved under `model-prefix`.
+
+  It declares two float32 inputs whose leading dimension is `batch-size`:
+  \"data\" ([channels height width], NCHW) and \"label\" ([label-width], NT)."
+  []
+  (let [float-input (fn [input-name shape input-layout]
+                      {:name input-name :shape shape
+                       :layout input-layout :dtype dtype/FLOAT32})
+        inputs [(float-input "data" [batch-size channels height width] layout/NCHW)
+                (float-input "label" [batch-size label-width] layout/NT)]
+        factory (infer/model-factory model-prefix inputs)]
+    (infer/create-predictor factory)))
+
+(defn -main
+  "Prints the digits predicted for a captcha image. The image path is the
+  first command-line argument (default `captcha_example.png`); the image is
+  resized to `width` x `height` before prediction. An all-zero label array
+  is supplied for the predictor's \"label\" input."
+  [& args]
+  (let [image-fname (or (first args) "captcha_example.png")
+        image-ndarray (-> (infer/load-image-from-file image-fname)
+                          (infer/reshape-image width height)
+                          (infer/buffered-image-to-pixels [channels height width])
+                          (ndarray/expand-dims 0))
+        label-ndarray (ndarray/zeros [1 label-width])
+        raw-output (infer/predict-with-ndarray (create-predictor)
+                                               [image-ndarray label-ndarray])
+        digits (->> (ndarray/argmax (first raw-output) 1)
+                    ndarray/->vec
+                    (mapv int))]
+    (println "CAPTCHA output:" (apply str digits))))
diff --git a/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
new file mode 100644
index 000000000000..91ec2fff3af7
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/src/captcha/train_ocr.clj
@@ -0,0 +1,156 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License.
You may obtain a copy of the License at
+;;
+;;     http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr
+  (:require [captcha.consts :refer :all]
+            [clojure.java.io :as io]
+            [clojure.java.shell :refer [sh]]
+            [org.apache.clojure-mxnet.callback :as callback]
+            [org.apache.clojure-mxnet.context :as context]
+            [org.apache.clojure-mxnet.eval-metric :as eval-metric]
+            [org.apache.clojure-mxnet.initializer :as initializer]
+            [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.optimizer :as optimizer]
+            [org.apache.clojure-mxnet.symbol :as sym])
+  (:gen-class))
+
+;; Fetch the data set on first load and fail loudly if the download script fails.
+(when-not (.exists (io/file "captcha_example/captcha_train.lst"))
+  (let [{:keys [exit err]} (sh "./get_data.sh")]
+    (when-not (zero? exit)
+      (throw (ex-info "get_data.sh failed" {:exit exit :err err})))))
+
+(defonce train-data
+  (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_train.rec"
+                            :path-imglist "captcha_example/captcha_train.lst"
+                            :batch-size batch-size
+                            :label-width label-width
+                            :data-shape data-shape
+                            :shuffle true
+                            :seed 42}))
+
+(defonce eval-data
+  (mx-io/image-record-iter {:path-imgrec "captcha_example/captcha_test.rec"
+                            :path-imglist "captcha_example/captcha_test.lst"
+                            :batch-size batch-size
+                            :label-width label-width
+                            :data-shape data-shape}))
+
+(defn accuracy
+  "Fraction of examples in `label` (one row per example) whose characters are
+  all predicted correctly by `pred` (class scores, one row per character in
+  transposed label order); `:by-character true` scores each character alone."
+  [label pred & {:keys [by-character] :or {by-character false}}]
+  (let [[nr nc] (ndarray/shape-vec label)
+        pred-context (ndarray/context pred)
+        label-t (-> label
+                    ndarray/transpose
+                    (ndarray/reshape [-1])
+                    (ndarray/as-in-context pred-context))
+        pred-label (ndarray/argmax pred 1)
+        matches (ndarray/equal label-t pred-label)
+        [digit-matches] (-> matches ndarray/sum ndarray/->vec)
+        ;; an example is fully correct when all `nc` of its characters match
+        [complete-matches] (-> matches
+                               (ndarray/reshape [nc nr])
+                               (ndarray/sum 0)
+                               (ndarray/equal nc)
+                               ndarray/sum
+                               ndarray/->vec)]
+    (if by-character
+      (float (/ digit-matches nr nc))
+      (float (/ complete-matches nr)))))
+
+(defn get-data-symbol
+  "Scores branch: four conv/pool/relu stages, a 256-unit dense layer, then one
+  `num-labels`-way head per character (`label-width` heads) concatenated on dim 0."
+  []
+  (let [data (sym/variable "data")
+        ;; normalize the input pixels
+        scaled (sym/div (sym/- data 127) 128)
+        conv1 (sym/convolution {:data scaled :kernel [5 5] :num-filter 32})
+        pool1 (sym/pooling {:data conv1 :pool-type "max" :kernel [2 2] :stride [1 1]})
+        relu1 (sym/activation {:data pool1 :act-type "relu"})
+        conv2 (sym/convolution {:data relu1 :kernel [5 5] :num-filter 32})
+        pool2 (sym/pooling {:data conv2 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu2 (sym/activation {:data pool2 :act-type "relu"})
+        conv3 (sym/convolution {:data relu2 :kernel [3 3] :num-filter 32})
+        pool3 (sym/pooling {:data conv3 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu3 (sym/activation {:data pool3 :act-type "relu"})
+        conv4 (sym/convolution {:data relu3 :kernel [3 3] :num-filter 32})
+        pool4 (sym/pooling {:data conv4 :pool-type "avg" :kernel [2 2] :stride [1 1]})
+        relu4 (sym/activation {:data pool4 :act-type "relu"})
+        flattened (sym/flatten {:data relu4})
+        fc1 (sym/fully-connected {:data flattened :num-hidden 256})
+        ;; one classifier head per character, built eagerly in label order
+        heads (mapv (fn [_] (sym/fully-connected {:data fc1 :num-hidden num-labels}))
+                    (range label-width))]
+    (sym/concat "concat" nil heads {:dim 0})))
+
+(defn get-label-symbol
+  "Label branch: the \"label\" variable transposed and flattened to one dimension."
+  []
+  (as-> (sym/variable "label") label
+    (sym/transpose {:data label})
+    (sym/reshape {:data label :shape [-1]})))
+
+(defn create-captcha-net
+  "Softmax-output symbol over `get-data-symbol` scores and `get-label-symbol` labels."
+  []
+  (let [scores (get-data-symbol)
+        labels (get-label-symbol)]
+    (sym/softmax-output {:data scores :label labels})))
+
+(def optimizer
+  (optimizer/adam {:learning-rate 0.0002 :wd 0.00001 :clip-gradient 10}))
+
+(defn train-ocr
+  "Fits the captcha net on `train-data` for 10 epochs using the contexts `devs`,
+  reporting exact-match `accuracy` as the metric; returns the fitted module."
+  [devs]
+  (println "Starting the captcha training ...")
+  (let [model (m/module (create-captcha-net)
+                        {:data-names ["data"] :label-names ["label"]
+                         :contexts devs})
+        fit-params (m/fit-params
+                    {:kvstore "local"
+                     :batch-end-callback (callback/speedometer batch-size 100)
+                     :initializer (initializer/xavier {:factor-type "in"
+                                                       :magnitude 2.34})
+                     :optimizer optimizer
+                     :eval-metric (eval-metric/custom-metric #(accuracy %1 %2)
+                                                             "accuracy")})]
+    (m/fit model {:train-data train-data
+                  :eval-data eval-data
+                  :num-epoch 10
+                  :fit-params fit-params})
+    (println "Finished the fit")
+    model))
+
+(defn -main
+  "Entry point, args `[dev num-devices]`: trains on `num-devices` (default 1) GPU
+  contexts if `dev` is \":gpu\", else CPU contexts, then saves a checkpoint (epoch 0)."
+  [& args]
+  (let [[dev dev-num] args
+        num-devices (Integer/parseInt (or dev-num "1"))
+        devs (if (= dev ":gpu")
+               (mapv #(context/gpu %) (range num-devices))
+               (mapv #(context/cpu %) (range num-devices)))
+        model (train-ocr devs)]
+    (m/save-checkpoint model {:prefix model-prefix :epoch 0})))
diff --git a/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
new file mode 100644
index 000000000000..ab785f7fedf2
--- /dev/null
+++ b/contrib/clojure-package/examples/captcha/test/captcha/train_ocr_test.clj
@@ -0,0 +1,119 @@
+;;
+;; Licensed to the Apache Software Foundation (ASF) under one or more
+;; contributor license agreements. See the NOTICE file distributed with
+;; this work for additional information regarding copyright ownership.
+;; The ASF licenses this file to You under the Apache License, Version 2.0
+;; (the "License"); you may not use this file except in compliance with
+;; the License. You may obtain a copy of the License at
+;;
+;;     http://www.apache.org/licenses/LICENSE-2.0
+;;
+;; Unless required by applicable law or agreed to in writing, software
+;; distributed under the License is distributed on an "AS IS" BASIS,
+;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+;; See the License for the specific language governing permissions and
+;; limitations under the License.
+;;
+
+(ns captcha.train-ocr-test
+  (:require [clojure.test :refer :all]
+            [captcha.consts :refer :all]
+            [captcha.train-ocr :refer :all]
+            [org.apache.clojure-mxnet.io :as mx-io]
+            [org.apache.clojure-mxnet.module :as m]
+            [org.apache.clojure-mxnet.ndarray :as ndarray]
+            [org.apache.clojure-mxnet.shape :as shape]
+            [org.apache.clojure-mxnet.util :as util]))
+
+(deftest test-consts  ; pins the constants the shape assertions below rely on
+  (is (= 8 batch-size))
+  (is (= [3 30 80] data-shape))
+  (is (= 4 label-width))
+  (is (= 10 num-labels)))
+
+(deftest test-labeled-data  ; shape and value-range checks on one batch per iterator
+  (let [train-batch (mx-io/next train-data)  ; takes a batch from the shared iterator
+        eval-batch (mx-io/next eval-data)
+        allowed-labels (into #{} (map float (range 10)))]  ; labels are the digits 0.0-9.0
+    (is (= 8 (-> train-batch mx-io/batch-index count)))
+    (is (= 8 (-> eval-batch mx-io/batch-index count)))
+    (is (= [8 3 30 80] (-> train-batch
+                           mx-io/batch-data
+                           first
+                           ndarray/shape-vec)))
+    (is (= [8 3 30 80] (-> eval-batch
+                           mx-io/batch-data
+                           first
+                           ndarray/shape-vec)))
+    (is (every? #(<= 0 % 255) (-> train-batch
+                                  mx-io/batch-data
+                                  first
+                                  ndarray/->vec)))
+    (is (every? #(<= 0 % 255) (-> eval-batch
+                                  mx-io/batch-data
+                                  first
+                                  ndarray/->vec)))
+    (is (= [8 4] (-> train-batch
+                     mx-io/batch-label
+                     first
+                     ndarray/shape-vec)))
+    (is (= [8 4] (-> eval-batch
+                     mx-io/batch-label
+                     first
+                     ndarray/shape-vec)))
+    (is (every? allowed-labels (-> train-batch
+                                   mx-io/batch-label
+                                   first
+                                   ndarray/->vec)))
+    (is (every? allowed-labels (-> eval-batch
+                                   mx-io/batch-label
+                                   first
+                                   ndarray/->vec)))))
+
+(deftest test-model  ; one forward/backward pass on a freshly initialised network
+  (let [batch (mx-io/next train-data)
+        model (m/module (create-captcha-net)
+                        {:data-names ["data"] :label-names ["label"]})
+        _ (m/bind model
+                  {:data-shapes (mx-io/provide-data-desc train-data)
+                   :label-shapes (mx-io/provide-label-desc train-data)})
+        _ (m/init-params model)
+        _ (m/forward-backward model batch)
+        output-shapes (-> model
+                          m/output-shapes
+                          util/coerce-return-recursive)
+        outputs (-> model
+                    m/outputs-merged
+                    first)
+        grads (->> model m/grad-arrays (map first))]
+    (is (= [["softmaxoutput0_output" (shape/->shape [8 10])]]
+           output-shapes))
+    (is (= [32 10] (-> outputs ndarray/shape-vec)))  ; 4 heads x batch of 8 rows, 10 classes
+    (is (every? #(<= 0.0 % 1.0) (-> outputs ndarray/->vec)))  ; softmax probabilities
+    (is (= [[32 3 5 5] [32] ; convolution1 weights+bias
+            [32 32 5 5] [32] ; convolution2 weights+bias
+            [32 32 3 3] [32] ; convolution3 weights+bias
+            [32 32 3 3] [32] ; convolution4 weights+bias
+            [256 28672] [256] ; fully-connected1 weights+bias
+            [10 256] [10] ; 1st label scores
+            [10 256] [10] ; 2nd label scores
+            [10 256] [10] ; 3rd label scores
+            [10 256] [10]] ; 4th label scores
+           (map ndarray/shape-vec grads)))))
+
+(deftest test-accuracy
+  (let [labels (ndarray/array [1 2 3 4,
+                               5 6 7 8]
+                              [2 4])  ; 2 examples x 4 characters
+        pred-labels (ndarray/array [1 0,  ; 1st characters of example 1 and 2
+                                    2 6,  ; 2nd characters
+                                    3 0,  ; 3rd characters
+                                    4 8]  ; 4th characters
+                                   [8])
+        preds (ndarray/one-hot pred-labels 10)]
+    (is (float? (accuracy labels preds)))
+    (is (float? (accuracy labels preds :by-character false)))
+    (is (float? (accuracy labels preds :by-character true)))
+    (is (= 0.5 (accuracy labels preds)))  ; example 1 fully correct, example 2 not
+    (is (= 0.5 (accuracy labels preds :by-character false)))
+    (is (= 0.75 (accuracy labels preds :by-character true)))))  ; 6 of 8 characters match