From 777b1b732dde8d8ee32d96b3b34fe71b63a35841 Mon Sep 17 00:00:00 2001
From: Punit Vara
Date: Mon, 20 May 2019 14:00:51 +0530
Subject: [PATCH] port code to tensorflow2.0

Compile and run code with tf2.0. Eager execution is disabled.
Tested code with python 3.5

Signed-off-by: Punit Vara
---
 tf2.0/README.md | 47 +
 tf2.0/docs/FastCells.md | 57 +
 tf2.0/docs/img/3PartsGraph.png | Bin 0 -> 25667 bytes
 tf2.0/docs/img/FastGRNN.png | Bin 0 -> 13187 bytes
 tf2.0/docs/img/FastGRNN_eq.png | Bin 0 -> 10509 bytes
 tf2.0/docs/img/FastRNN.png | Bin 0 -> 11485 bytes
 tf2.0/docs/img/FastRNN_eq.png | Bin 0 -> 4709 bytes
 tf2.0/docs/img/MIML_illustration.png | Bin 0 -> 23944 bytes
 tf2.0/edgeml/__init__.py | 13 +
 tf2.0/edgeml/graph/__init__.py | 2 +
 tf2.0/edgeml/graph/bonsai.py | 180 ++
 tf2.0/edgeml/graph/protoNN.py | 191 ++
 tf2.0/edgeml/trainer/__init__.py | 2 +
 tf2.0/edgeml/trainer/bonsaiTrainer.py | 560 ++++++
 tf2.0/edgeml/trainer/fastTrainer.py | 527 ++++++
 tf2.0/edgeml/trainer/protoNNTrainer.py | 219 +++
 tf2.0/edgeml/utils.py | 339 ++++
 tf2.0/examples/Bonsai/README.md | 67 +
 tf2.0/examples/Bonsai/bonsai_example.ipynb | 1135 ++++++++++++
 tf2.0/examples/Bonsai/bonsai_example.py | 115 ++
 tf2.0/examples/Bonsai/fetch_usps.py | 64 +
 tf2.0/examples/Bonsai/helpermethods.py | 270 +++
 tf2.0/examples/Bonsai/process_usps.py | 54 +
 tf2.0/examples/Bonsai/quantizeBonsaiModels.py | 72 +
 tf2.0/examples/FastCells/README.md | 77 +
 .../examples/FastCells/fastcell_example.ipynb | 1557 +++++++++++++++++
 tf2.0/examples/FastCells/fastcell_example.py | 99 ++
 tf2.0/examples/FastCells/fetch_usps.py | 66 +
 tf2.0/examples/FastCells/helpermethods.py | 273 +++
 tf2.0/examples/FastCells/process_usps.py | 41 +
 .../examples/FastCells/quantizeFastModels.py | 135 ++
 tf2.0/examples/ProtoNN/README.md | 54 +
 tf2.0/examples/ProtoNN/fetch_usps.py | 64 +
 tf2.0/examples/ProtoNN/helpermethods.py | 206 +++
 tf2.0/examples/ProtoNN/process_usps.py | 51 +
 tf2.0/examples/ProtoNN/protoNN_example.ipynb | 449 +++++
 tf2.0/examples/ProtoNN/protoNN_example.py | 88 +
 tf2.0/requirements-cpu.txt | 7 +
 tf2.0/requirements-gpu.txt | 7 +
 tf2.0/setup.py | 9 +
 40 files changed, 7097 insertions(+)
 create mode 100644 tf2.0/README.md
 create mode 100644 tf2.0/docs/FastCells.md
 create mode 100755 tf2.0/docs/img/3PartsGraph.png
 create mode 100644 tf2.0/docs/img/FastGRNN.png
 create mode 100644 tf2.0/docs/img/FastGRNN_eq.png
 create mode 100644 tf2.0/docs/img/FastRNN.png
 create mode 100644 tf2.0/docs/img/FastRNN_eq.png
 create mode 100755 tf2.0/docs/img/MIML_illustration.png
 create mode 100644 tf2.0/edgeml/__init__.py
 create mode 100644 tf2.0/edgeml/graph/__init__.py
 create mode 100644 tf2.0/edgeml/graph/bonsai.py
 create mode 100644 tf2.0/edgeml/graph/protoNN.py
 create mode 100644 tf2.0/edgeml/trainer/__init__.py
 create mode 100644 tf2.0/edgeml/trainer/bonsaiTrainer.py
 create mode 100644 tf2.0/edgeml/trainer/fastTrainer.py
 create mode 100644 tf2.0/edgeml/trainer/protoNNTrainer.py
 create mode 100644 tf2.0/edgeml/utils.py
 create mode 100644 tf2.0/examples/Bonsai/README.md
 create mode 100644 tf2.0/examples/Bonsai/bonsai_example.ipynb
 create mode 100644 tf2.0/examples/Bonsai/bonsai_example.py
 create mode 100644 tf2.0/examples/Bonsai/fetch_usps.py
 create mode 100644 tf2.0/examples/Bonsai/helpermethods.py
 create mode 100644 tf2.0/examples/Bonsai/process_usps.py
 create mode 100644 tf2.0/examples/Bonsai/quantizeBonsaiModels.py
 create mode 100644 tf2.0/examples/FastCells/README.md
 create mode 100644 tf2.0/examples/FastCells/fastcell_example.ipynb
 create mode 100644 tf2.0/examples/FastCells/fastcell_example.py
 create mode 100644 tf2.0/examples/FastCells/fetch_usps.py
 create mode 100644 tf2.0/examples/FastCells/helpermethods.py
 create mode 100644 tf2.0/examples/FastCells/process_usps.py
 create mode 100644 tf2.0/examples/FastCells/quantizeFastModels.py
 create mode 100644 tf2.0/examples/ProtoNN/README.md
 create mode 100644 tf2.0/examples/ProtoNN/fetch_usps.py
 create mode 100644 tf2.0/examples/ProtoNN/helpermethods.py
 create mode 100644 tf2.0/examples/ProtoNN/process_usps.py
 create mode 100644 tf2.0/examples/ProtoNN/protoNN_example.ipynb
 create mode 100644 tf2.0/examples/ProtoNN/protoNN_example.py
 create mode 100644 tf2.0/requirements-cpu.txt
 create mode 100644 tf2.0/requirements-gpu.txt
 create mode 100644 tf2.0/setup.py

diff --git a/tf2.0/README.md b/tf2.0/README.md
new file mode 100644
index 000000000..5cbb13bc9
--- /dev/null
+++ b/tf2.0/README.md
@@ -0,0 +1,47 @@
+## Edge Machine Learning: TensorFlow Library
+
+This directory includes TensorFlow implementations of various techniques and
+algorithms developed as part of EdgeML. Currently, the following algorithms are
+available in TensorFlow:
+
+1. [Bonsai](../docs/publications/Bonsai.pdf)
+2. [EMI-RNN](../docs/publications/emi-rnn-nips18.pdf)
+3. [FastRNN & FastGRNN](../docs/publications/FastGRNN.pdf)
+4. [ProtoNN](../docs/publications/ProtoNN.pdf)
+
+The TensorFlow compute graphs for these algorithms are packaged as
+`edgeml.graph`. Trainers for these algorithms are in `edgeml.trainer`. Usage
+directions and examples for these algorithms are provided in the `examples`
+directory. To get started with any of the provided algorithms, please follow
+the notebooks in the `examples` directory.
+
+## Installation
+
+Use pip and the provided requirements files to first install the required
+dependencies before installing the `edgeml` library. Details for CPU-based
+and GPU-based installation are provided below.
+
+It is highly recommended that EdgeML be installed in a virtual environment. Please create
+a new virtual environment using your environment manager ([virtualenv](https://virtualenv.pypa.io/en/stable/userguide/#usage) or [Anaconda](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#creating-an-environment-with-commands)).
+Make sure the new environment is active before running the commands below.
+
+### CPU
+
+```
+pip install -r requirements-cpu.txt
+pip install -e .
+```
+
+Tested on Python 3.5 with TensorFlow 2.0.
+
+### GPU
+
+Install the appropriate CUDA and cuDNN versions [tested with CUDA >= 8.1 and cuDNN >= 6.1]:
+
+```
+pip install -r requirements-gpu.txt
+pip install -e .
+```
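+
+To verify the installation, a quick import check can be run (a minimal
+sketch; it assumes the requirements above have been installed and uses the
+package layout from this patch):
+
+```python
+# Smoke test for the editable `edgeml` install.
+import edgeml.graph    # compute graphs, packaged as edgeml.graph
+import edgeml.trainer  # trainers for those graphs, packaged as edgeml.trainer
+print("edgeml installed:", edgeml.__name__)
+```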
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+Licensed under the MIT license.

diff --git a/tf2.0/docs/FastCells.md b/tf2.0/docs/FastCells.md
new file mode 100644
index 000000000..213dfe1cd
--- /dev/null
+++ b/tf2.0/docs/FastCells.md
@@ -0,0 +1,57 @@
+# FastRNN and FastGRNN - FastCells
+
+This document explains and elaborates on specific details of the FastCells
+implemented in `tf2.0/edgeml/graph/rnn.py`. The endpoint use-case scripts with
+3-phase training, along with an example notebook, are present in `tf2.0/examples/FastCells/`.
+One can use the endpoint script to test the RNN architectures on any dataset
+while specifying budget constraints as hyper-parameters, in terms of the sparsity and rank
+of the weight matrices.
+
+# FastRNN
+![FastRNN](img/FastRNN.png)
+![FastRNN Equation](img/FastRNN_eq.png)
+
+# FastGRNN
+![FastGRNN Base Architecture](img/FastGRNN.png)
+![FastGRNN Base Equation](img/FastGRNN_eq.png)
+
+# Plug and Play Cells
+
+`FastRNNCell` and `FastGRNNCell`, present in `edgeml.graph.rnn`, are very similar to
+TensorFlow's built-in `RNNCell`, `GRUCell`, `BasicLSTMCell`, and `UGRNNCell`, allowing us to
+replace any of the standard RNN cells in an architecture with FastCells.
+The plug-and-play nature can be seen in the endpoint script for FastCells, where the graph
+building is very similar to that of LSTM/GRU in TensorFlow.
+
+Script: [Endpoint Script](../examples/FastCells/fastcell_example.py)
+
+Example Notebook: [iPython Notebook](../examples/FastCells/fastcell_example.ipynb)
+
+Cells: [FastRNNCell](../edgeml/graph/rnn.py#L206) and [FastGRNNCell](../edgeml/graph/rnn.py#L31).
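+
+For instance, a minimal sketch of swapping a standard cell for a FastCell in a
+static graph (since eager execution is disabled in this port, the `tf.compat.v1`
+API is used; the `wRank`/`uRank` arguments follow the Compression section below,
+and the exact constructor signature should be checked against `rnn.py`):
+
+```python
+# Minimal sketch: an unrolled RNN with FastGRNNCell in place of a standard cell.
+import tensorflow as tf
+from edgeml.graph.rnn import FastGRNNCell
+
+tf.compat.v1.disable_eager_execution()
+
+num_timesteps, num_feats, hidden_size = 8, 16, 32
+X = tf.compat.v1.placeholder(tf.float32, [None, num_timesteps, num_feats])
+x_seq = tf.unstack(X, axis=1)  # list of [batch, num_feats] tensors
+
+# Drop-in replacement for, e.g., tf.compat.v1.nn.rnn_cell.GRUCell(hidden_size);
+# wRank/uRank induce the low-rank parameterization described below.
+cell = FastGRNNCell(hidden_size, wRank=8, uRank=8)
+outputs, state = tf.compat.v1.nn.static_rnn(cell, x_seq, dtype=tf.float32)
+```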
+
+# 3 phase Fast Training
+
+`FastCells`, similar to `Bonsai`, use a 3-phase training routine to induce the right
+support and sparsity in the weight matrices. With the low-rank parameterization of the weights
+followed by the 3-phase training, we obtain FastRNN and FastGRNN models that are compact
+and can be further compressed using byte quantization without significant loss in accuracy.
+
+# Compression
+
+1) Low-Rank Parameterization of Weight Matrices (L)
+2) Sparsity (S)
+3) Quantization (Q)
+
+Low rank is induced directly into the FastCells during initialization, and training happens with
+the targeted low-rank versions of the weight matrices. One can use the `wRank` and `uRank` parameters
+of FastCells to achieve this.
+
+Sparsity is taken in as a hyper-parameter by `fastTrainer.py` during the 3-phase training, which at the
+end produces a sparse, low-rank model.
+
+Further compression is achieved by byte quantization and can be performed using the `quantizeFastModels.py`
+script, which is part of `tf2.0/examples/FastCells/`. This gives a model-size reduction of up to 4x if 8-bit
+integers are used. Lastly, to facilitate all-integer arithmetic, including in the non-linearities, one could
+use `quantTanh` instead of `tanh` and `quantSigm` instead of `sigmoid` as the non-linearities in the RNN
+cells, followed by byte quantization. These non-linearities can be set using the appropriate parameters in
+the `FastRNNCell` and `FastGRNNCell`.
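+
+For example, a sketch of selecting the quantized non-linearities (the keyword
+names `gate_non_linearity` and `update_non_linearity` are assumptions for
+illustration; see `rnn.py` for the exact signatures):
+
+```python
+# Keyword names below are hypothetical; verify against the cell definitions.
+from edgeml.graph.rnn import FastGRNNCell
+
+hidden_size = 32
+cell = FastGRNNCell(hidden_size,
+                    gate_non_linearity="quantSigm",    # in place of "sigmoid"
+                    update_non_linearity="quantTanh")  # in place of "tanh"
+```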
diff --git a/tf2.0/docs/img/3PartsGraph.png b/tf2.0/docs/img/3PartsGraph.png
new file mode 100755
index 0000000000000000000000000000000000000000..66ebdbcf9e2a47d81270a8592ac287be70ccb21a
GIT binary patch
literal 25667
[base85-encoded PNG data omitted]

diff --git a/tf2.0/docs/img/FastGRNN.png b/tf2.0/docs/img/FastGRNN.png
new file mode 100644
index 0000000000000000000000000000000000000000..2357165e7ed80aaedad18b401670f7e80caefe18
GIT binary patch
literal 13187
[base85-encoded PNG data omitted]

diff --git a/tf2.0/docs/img/FastGRNN_eq.png b/tf2.0/docs/img/FastGRNN_eq.png
new file mode 100644
index 0000000000000000000000000000000000000000..9df4789540a19d0b0605fb3da50e6bacfd8b14a5
GIT binary patch
literal 10509
[base85-encoded PNG data omitted]

diff --git a/tf2.0/docs/img/FastRNN.png b/tf2.0/docs/img/FastRNN.png
new file mode 100644
index 0000000000000000000000000000000000000000..d8826c49388059d98fddbd47e9283023e521e7df
GIT binary patch
literal 11485
[base85-encoded PNG data omitted]

diff --git a/tf2.0/docs/img/FastRNN_eq.png b/tf2.0/docs/img/FastRNN_eq.png
new file mode 100644
index 0000000000000000000000000000000000000000..bcf52cb29d8c18de355bb7d5158b46385f01c9d6
GIT binary patch
literal 4709
[base85-encoded PNG data omitted]

diff --git a/tf2.0/docs/img/MIML_illustration.png b/tf2.0/docs/img/MIML_illustration.png
new file mode 100755
index 0000000000000000000000000000000000000000..7c1ab545657ef7ff46fe88a12705f905c2617166
GIT binary patch
literal 23944
[base85-encoded PNG data omitted; patch truncated here]
zjY&fq$3#-`(tks6AP03nrR8HvCC0RYL=!YFbgk(hdVE57wfBTvfU3c&Cbc;Lv{c7 z5vi%6!C(Deh1vYuw{PJ*ZLDl;N$|`8Z|T&7cRK8rzdkCNp%U}*SW@_{2|xGLdF~|2 z)5qGT9Cv&&^7Om|U;4(Yv1{kHilN~Ta&mIg3H!6-J$-VedffQr zJ8ES-ZPhLhyH>(v2%jh@yjHB4zkpL+RKOg7KY97ztJ^soNg<-6s;c_w5&MG%>eFWLNk8+__4E{ zI+f@Cn^riPtCqeP^84sb$lZAR=1mKU+3LZRU&8Sg!^TPziP=>=N^{gmn%_CgUb2Ge z$%!ooi=P;=t*tFF8QFJiY-}AnI}S1))63{hxRjzD6mz?Oe@#?DNpU{&#iB`;DgGWg z<$GWdcJknNz4t8DwY3@QIgj9$$~}9=@8jdM`4woW!53|+c6jF|xC)!I-*7~Ik1M_{(FWvlWk4gbGG3=0$|!c3D&u_1wOU;`ZD&hYmXZsaG7ianNEgaLGF;zeu$_oaSI8x|LMz23IEpIde?QV$YIKZgX4xq<;3dKk!RQK>-dISBYgK?rUQ^yX$OhZ0pZWO=(fnBV}fZd`}_= z?BKhj{Ncz2zDvQ+(8a#JRt;_P+44Z13|eAR(tW7C7#JAb+}!@K@~~LdSD4Zy{34+r zd3bp>kB^g0JwM?H4-YSWI?4+T92I?*&2ps4`juqvZb=scU_RUbZiWwQVr?DT(IKDY z0*}Um<N0EX1P3 zwf{+YdSPMsm!32UD)o=FOW| zY}&=ffbA0RxGtsa;~|0asp9N^0ddqlgHZiHQlffPl=w`fm%$gafgJw6v(i zL=pfWluv%VyRWHvGd4DMdTA*b+SNM=Kf{0bk!bJs&xzr58fX+y;?y$le*gZROo^O< zK0$K0#3*X6J&F!0@fQFiwA|eIzQ?=P`z~-`C2C?zh7L~{qH1gJ0zAL~jDm(586S6P zlwMm~tHGm39qtVqq1Dya(c)uP<7HTUfBWF9>vxR{6Q!sJwYDhSwM22>yC~z{RI$z7 z-8ip_rY3*ec-V?}6TNTnHT!=3l5%w9Ml}r&YcryK@9KK_`LhtnzSg4Xos>sO5M5xa0Mo8a_JtNw>{WErW#?{NqYvKZMBEjw3p-a7K=pKVF z_n`k#3kiMhY%MZse2((3_dZZyu$7UK+1%d`3JVL-(F=3)#G(mne6p%+&HL3Fyx7=@v5Kt zpoh2jr&m|;c-u&J4i1*8#y>ky?dl z-n|A`E{5&M+0D%|G&B@&V^=&InPM%YQgUlXMn+bZ#PieYZXY4&Y)^r8sCBAAX+K#^ zaj0Gc&^)@>laloQJ;l!M?tD<;g@J_ygNv){!oG`}Td{j$1ANNo&!0a%KRr@pCl*?{ zb0fnh=h6GvWHJS39864O$@J*);{5#AHMO<8A|fK|a)EUlH9Jea!zbdn*RL-`a`g)C zS=&f-p6u_@E2Q@NAMd}>DK~!*W*7{7!f6~T18y zOA8BMJ327o%y?|iV?hfsHZeg(gprfPX`0zG?B>yjvb^R;45-~a|HC0T}VM<=-`-LS8cv`?Yl1f0%W=ij z+PX7U!jJP)uS4Tmn;lnDlI}JdB{_LFs+n z%!m*LT6aG2|H#J4Y55A{x@n4#!=eb0WXQV;)+5u^K`&Pt8k(mgr6&4QMgIP%edzR^ zPvIyStV?;pg6Ml{X}ui?KA1Xjv;LKE)7tyx%a@730Dlb*T70tXnH3=qwEgkKzX~Zk zzxw)2f`WqL0thnDl6A*bi^Ywg7`H&X>ihJh_pSc%ypo%nAYj^b=xpBJ-gz{0EG#VY z4gc=$2p7Dls16B1rAO)(km{#K*1T&uYtdfATLTf zy4TN651sOAy6Ja5oNcz?9sbFTI>Zjc@4uvAG#-QB&#P2)ecyfWEYh^3miVU}si|We5Z#7#`6Wi9&dLpd-3D%jd3k+f-y?yy zRC^zAQi}V8$VE|p1>^un(9zvJI46e*w*3fKb;hu9HMgXs*T7A`-plpX3IFGFd`Y65 zyn{6T&RaCnd%DL9EDxHY0)A9Z`@u1l*~w{88vvbS4!SB3f7@r{O-)U0OT&Hl_@4Y| zh2r?3pkUm*pW#0I?NjUJvz*;Z3M?9qU-M(e*L6NwUI$7;Kth6W(zv&8S%rBAAIQr$ z<*H_Kn>1gBN9J_4;&|-UfQ#xD-^flSb3h&1XLKR zZ%EYWHbhd2Vc)uS%brFD)RvFi+s^EIRo9R@pyzWO@m$4+4 z>&A^6h$f@pzO$%wx|4Px4Hcn23XkG0(Hlmkmr(F(-PT@bWzkR6dEUrTPJ27XDJV$1 zv)I!TLno&}AV?F?p#&Irz}+8z&^9b{5HMpXu$SQ!kDe4EyxX^LR|9bc0R~e-LPGJK zFaeuZVN;%Z4n7qX7vHTJ;O9;=%~v!7OSHyJ&uF9#n9VEZ%FNoiJvNPZH>QkCUJyu< z4i7!w?FPI7$3LFM-OWu_ULF&!{Lnm)VVwt>va+(dxw)f-prD|~;U*0LCnD04T3g$_ zV0>4}ObV<%6}yG+d|X_3@Nt>}pk#IJ?d_G>|GryjSWkNQ?p{KAs*Ms!VSvu)S8@3Z{It!)kLGXWousFycl`LF)lnw! 
zqH2grlK+LItVk+AAGj?f!j2!BFVO|?KOz%Rl<+$zcc|YZE-x?7>iX<36$#jfWmi9( zXJl-QmXGfZ+jxcbjiJKFnLl@O+3X4>0<5q}*W^Qj`m4BlfwtA{4e9Z^c=&>LqElfN zqxR!A=u(xo!_EiZd&_~lyU#E$U&?-`{N73r4#%6HW03(?%yy(9>_pf+mfa*ttsrg~ zC_Jq5*S_(oL0p%YkMG7UE<=s6qk&MtL3ry3$sxB>U_ize7J=zf!L01uu-q!3m|)#t zoL(L*P+I_IE$ir5+TSYyTL{=L=!=Bi7=PXgtg?-BrB7gRFrA>F)@T{NWJPpZTAJc) zrSaEG@$`l$pb1>A671~Plv71jJ8Q6Lvc&nvV(Nsnd z$!u%rpWzZJcO!OUGq^mgA`(Uw>F4+vc|^QpfYQSdGNb2EpIzF7NfYsF7xx!%mS&)X z*>;=FygcYr zE*2o%RX;zwTUb=&8JcJC#qH`?DIei*RJ?#?u6y>E zFgzlH$LG*75JL{Lj}wYJ)KxhjA5ra(nVGMEsd#vJY?9noQCCL?_B_9^;AC-re$Kk) zCyPnr|BhlE`gVA3F0(~D=(3It4Gn6cw@^MOYoF_w3{fC!Mu9q`5Px>ic<~LYt>^0y z{C=%ID>=!5xQ`K9C+8;{TPr1v3h_Ol#XL;lTwtMn9CnKm z5)$J7>J@rajDR9Fe$&tpF<|P&F4&@y}2cQ;}S@w{oWZ=V8|_OoN)YpVb^tT zF{twN^z`!a26Y|;larI8G6TGzoFEsa{L`nJa!@$CM1_>He!%Ir8~(S+!?{+UmBBM`uVa5qGC-QM2bgvP)I zwK`-GG?TAC)1^S>`Wg=?mWGFiKub$YRCdU#H%&qgghtiJk2~Y9lL=lEvY+@lk`;@e z(J)`T>T$AfcM1>biVVch_{hr2%I&(Wk=I%P8;yDW`ru(J4hopBxvceMt!o&k1fgV{ zWbu61U3aWRV-;BI>+7O2Ssu_vqm#^CVU?@>F2t1y;GN+xj*-GU zhlhnpJ$i)8sGM5(Pc>0bpc0cf-9!gnmR(a2P=&O#^d?cv_C%u}l~m9*kbq{pjt~A3 zJ-`ClCV9M*_wbA+iz5SCMW4#@=7J~O`D+w+HEh1FQi^#|v#?;nO3(A~E{s;$k!cm{ zvuIXX_0b~wH?o1ta3-w_tnB0@`H7}72LZ1mMFo5WCksFp)gV!WeB}}La1^v-gs1{a zE@tQC#D%UI3|bZ(c)Y2-;OHtvnoJ77CNZh0k$->X{ImLA`9($FK=L7=py1FvPrEpk zTO`a91lQwbW@e_w>}ZY46M9C*3;=ebva*=hK!tKo=cb%Bb&J?#(RGBCm8 zkk_J%2swHAs~}3~)IMWTCJ?}b3J@Rxw=gj-E&#sy%y0Lpe-Y?bp9vrgQV_jkkIO9V z?IT-TrAJ0aa`}l}RSXQSYinyK{bdB%Fa%Hh&_5@4UW#HJYX*6ov{c^l_no7iC8XJE zNH9@QP$0++$X7tutcnmqi9oGVP*AAmMMp<(ZEI@=VY9u1R1N`|svXT^0r-RJ3-{)` zx)cC(sj_ZeGBkMk^gG<<-PO?>Ovs(L<25g^Aq zg$7@=^1IvGFhJfJNLG&m-q@*FQ45m)>O{?(oDfo^$IV0m82R||BP}B%#>MH9#J{Dx zYn_{dj0^|hAc$3t79dZ`>Fa}PNgn5BsB|#jc>|6xQFeB={$ud`9Jgj=Kr2dZ@5oXO z_Le-9`loevoo@T={viNt88kaF01vE?X4jp{PAb3C7XVBZ3<$Ud1%V3*4Om{L zXD|xb43-rk%t((D(B)uax@y)I9-~l)GXDOSBhI7KdLUQz5KNrxF5pK%j$_`x?@`z2 zx2m8+Hp!|T@Qr@sa|#5VA%H43RfV~=)TAXkE{Y451i@g*YP(hffXad9Xy$jfw4gIRdKaC1FRnuwl;zEt7LvQqay5?N zlOlH`M4C30HWMo9<^#BwX=2_%AnHwLMv1HaL$z4STXsQADFb6FuN>ZH4pf|Ho{DI2 zvo?qH9A@AKt1WCaoNzjHa=2DA1xO$r$SY8IeMw8N(LfkLH4l~%XC zI?4zB91RnbT)q$BmBX1}Dh*0JWW^CK_P;=R+1ssXg63ES^<~W z7=V|s*KV$cUrBYfFv@Cgs=S5>A&ngWA&rHF>;Zv+U#F&w(!UZ&#ytUQtkd9Aqj-dD zI8dBzeARJB0Oek{=%zr>d&_CiB?ZUUl}CqTvCGc7)JQt zvdO=mP^t^dFNed)YNmpU-rdvl4LCt5RN25QvY6KT5AP~c0im80(CK}5=au!Zk1VYG zv^hX-Ou-)DU6452y#e>@!+?wW|2f=Pp+Q~$l?CWI+n)I1&IMnCD%0tXf1q1uHA;QH zI#Pz7_WU?cBMIzu5NNp_CPlN~C&Y67*Zt_ZT~y_^rdy~}Cb`^G44@8FI`!FapG;bU z(&458WQk_Q!@-dPQFdl_HY0AQAMn{>W^`X92#&%Yo752l*3E!2;2vC$k-rQmFj*=z zHy|yI3YT1%TKzewo88F*VZkBtzKFyGfUhggZ>L8{r_A&U9sz;mA+7dy{M)xdK!m<` zPflo^0b>&OI~N7}ViPX7`I9Gf5)zG*cHX;x(8$Qho^4E0fC}m8>iX>}mCuLSkWj(= z+}x({+kBf47Wgjec7E!qm@1kSoGWiu^>p+l=oXS6eZT`7xStwb70^k;!4WY!HrDki zZ|HS+xH$qLU7VdwQM#^te0*krUD+8KgEn$x26T*#jbBAA%+HtJe}&OGPz0^v)5wU< zP1y&@wgC3`o^v8gmcC?N2vR~^0U{-E7=dRe>^O}==+>2wzR5E{0!FHx=AyRL|Z?zO*)1$zgLtf{T- zV-Lg4y?W0^2v8InH5yv#f%u?b?TBVj`%G2SBP$+u76nBnY5Vaf%bC{D8%5&c|8<9p zvCaB0LD&FSwFM|^%7oi|N6hB-_CMca_s`egciq>+g3ibywy zhldzZ?jX-2@eHJJoF46z0WpJDV&4ACA`PTtc5W^vI@%v}VLHHQp!WBOKLS?(@tbsX zm=DB(34Qx;{}m!MTqEVj1ok^H_TqX8p7f9J-&-WkxA62Ed>&hVdGLmD23!noUS2f| z_%uGgz9yEILEtbe#HWdR62cvChc4X*x09GE?SHgTtU$?A(6DqRXqkW`!XdCN-SL68 zy|w;j1Hr&Ry*0ubRzW$8{P|M_F%^Kx+B%;=i!t4v?@SbSqMh`Zi?B>9sRZ}N5rjVg z%G3apfMaWUUQ;S7E%gUIy{@i~+ig`R`}*&thvw$=lYSSza0|YRR)N0pbfP*YI5#8% zF#mP1YF?-?$FXjVf9B%Uul~8H>w^WjU!eiO?ew2NA>fi|wq`j)n&NQvQ!lB0W4Zgr zq|Rsi_svYwZ~ze*%Mr>HI4KRCYe3DfI5;>cCi7!KKl(HNxkPfV21GT8C%gh9hgw{m z5|mqPv4;G_loT_dBuFemHi|L=4IN{XLM_bP!eSOG`_T 
zqGg<%cmPeQ=<8ogR|$Ff@=`H76I0Hh@3^9(;`H+JYY-2>4ZH-PW18>Vf6tIe=I?NcdBFUH?d}lWhXQ9dn6D`+4KytjSOqmT^@p$R z3kndW$S-wSdQ{}K6ZXy22MQtPwQJILc0~hbwY9Z6;IM!}cvVbHEWX_hsNw19POmUL zH0)T7MGrb0cml=<94am@hUzHO`238BSF(Tys&uu>67I01JDfMfu6_LeHP}f!Pe-`H zZZU&a3No3h*Tl9Lprq%AGgrw)JxV8~KLLx;`QpX_VW*$P1|pQS^j%u2sxRP@bae;e zp*HGvaDg3SK$?h>7G$HV5pIQYmX?g-htufNGBTZ_2)u@iWcKsJeNkz~BJc;mSb%2I z1Vv%%aLX*d9VtKHNSJj*)95z~z{|(C_eGbN^R|BbHUlC(^0jalk}DgbT47u#y$m%- zSXkJ2y6FW8zZL%2XNQukkZedxOhdW@F)B<=O-Uzy02>5=)fP=7r9pUWW;iG9IYJ>j z;>JfuXMp2@da*>qDj8D8eB`ABqASuO0H9=o=(@NCdodsqZe)OLJ7h|v0ARsC)p{~A zfs&C^dFo7@ouu%sk(ar8^=crwo+b$FP7sgMuYHCE8dWQJKYGugd>w!AyEy+d`iTjt zs?fA{;jBW_`U6!hEHViy7!*KJkO2=irrIH2vN`?rmF&Uap&~Nmb>Lc7{dp(t0?{Z; zKpk;^L-0^WK3i7+^%chLEP}+6a0*{9y$_w+c z#NnLCzdzH_;ijA0!#55FO}=uTpeQT)dnD{X5_1_UHYoTxts%oqGz;q_39+U6vmt~v z!aL<_hCOnbUtC0rG_So5vjr9c4GYVI#WEmtv&iN-JKJn+fxSxrS{UMagNlP7XP~3HH=I8q z6#;6O2Y`ziIDOEzzmASR&gMZ>KL#TR`~k?9RZ@yzHC0Q50LZk&#Sz}~qlH`R!=qCi z2DL=s0KWi<^mQhkPe^2B8#qbG-b%~K?NCI9KR-J*0Ur=_y1@AOn^QCR0xEFlK^UEZ zrOQ9q0?K3xh|*cPGzF3!fVi-LkV%`GX3$D$lE8L_lki*uAA>~PM_4w3^qaeU))k~r z0NMm*(*zW9d1y#a6p8?t5D-`ZJu~NZqHbY%IU&dr;QTiuVzJ(5dx7dA@R1?Y7Xale zpIb8H5mXKQTepJY&uE~PfEUuy(V^l=4nojjyTtj6DDh*=mZ0nEM@ONSm^e7E8XNuU zcj`gI|JvJ&2ks`+iX!(5kkw3~lL@;n$CXQI{!y$8nnS^&;lCIk8wb1X@|E^(bXMWo=Zteaynb4e}8fnmps?M z_h8826G7;A145jiGIxwyF0tmA_<1~uLy%m@dPAJ?eF$w82T zs`vtwqv7`%4Bwvu)JV_BkhH;m2W%_P;wQZjAKxe9_M%QmIz@ucgfKqXC+s6dpZAb# z0rNSa^9_^j>2!vcd$l`MV=&SoG-zUBaRstByMOx75!eL>2}DZ_2ysG=iq$SD0mO6B z=cgoiRN~S$HmvJYjZttg>O8h2*Uc2)-@TdUvyG*s@Ibc^swr$>M^6v%e9wB9xWCKG z&J>YGji3>`7+D5XXu;AVfuZ?%^FVS31c2@CwnCXigxti$MCSQwCHTnqB_%H*hze|t zqO5dgVq#)_N$8x60!4}-AHb>ts2a(jZ-#~ng?RDU#|^GJ{0h-up)>b=b})imlw9Vp zUId&^xB#2&RDQX2DX_<4+#snSuMGZ-+TR*ITDao?Unjit20OK(a3pK;ASXv%T^(Mh z2?!p1U$~u%#XAa&zllPnp`uMFZCfxt_IT_JX6oCE(Tj+{z$*}}K>{?o8L#j&Tyd$y z!#VV-rh%FEjFrHKCkogSe^g9*2?URorS>fg7Z;LTIzODlla`h3a+Q&KpOl+iiwdx%8TM^!+@Z13Yf}a?IY=4^bQ*B+G5|({j@ff= zh_DwVY6~H>3RX%Bc}2z6i?gYV8K~kYe{hYafixF=S=9mD0mti`T!A!#C;=uTLhy1g zbp0E}`ZX_aFg=2vpXaS&G(ZbeE88N34bvp+utC0@S7@2~pZ;7f*_$Ij0Ck00OmI#`cDd>J1O9dWof zG$7~zzHI||1Z9I`u)II^dk9$96Xj21QveevN^o~x#3QEd*~tMS79tG@T4)m_6(FU^ zk$QLsl0lFc^*BA`K|E3*g<09z>41_uJuz11W@ga9X9J{(0hbngjs5~QZ{~1o4kTC^ zXk3{*h$RR%epT}ijwonK0C;W)gMpU$odZ_z^Q3^M(eS)QJM^OBH|9{@4ldrkBcGq2 zr!h^+Hvz+6k7v|{cPA0^;(z!4ea+%5{B>9lG~k&n z0E?_TZ-G&JLo!{3m4=OrD`EMS8Kkrxy}#?^vM-ujsumJHN~A3@=FA~z8Vj%GX80x&;^O4@#$)Ty3^&>} z4GnQ2Ah@0oIBx{-83~$!@?Bb0rDd*D)J*a!DhhA$$NMm+xwiYt%6QO1NrddlhsoDx zXPe!@w3Hm^e#4}UP)+Ekh=vRlfy2g0Egp^>LWm}gKj$|P9lL|>csLuzg)nrXY9)x! zAP65%;uIfTMy;@325Z%Nr#TG!iRu?n4j9SeA$G6{AH?ivdjaX@3*y=-x`w&gh^r2T z9!RH?#n-Q25u+V(#vx})326oVj0d2|4F@YZgQatus|-TO#(*smun4O3U{ygM#4Zs! 
zaSG~Pg#DAJPs5Qu0xk?fBcQrK7zGn9EVA8Sp0Qh#vvP4E9SNU;A_A^YCxtRdAZ3u4 zfe7>G+=YrhWDtQEJS~LrL(5L~b+mKcUwy2YCQbp#(mR0Vn%#wU!b2?@c;$)uI!T7~Pm=LC{t zwU9+cTt$!o03=ESUH)7ON#BSbt`V?|08wwA@e{2>5gh>~3{lkw_T*1?|Gom=V>VV{ z4e5FH5hpUGhHKZatAe}@(-<&v!M2)}46+L*oJ0f}s;R|lsGBJ$U~82aMgaA-r)>ng zjs`FTlniE8)@Hy+UD|(0Lt!lfqCX-p3qb1-O!gGlzndFSF?x#b0Ty@(K#$dK`w%0= z4oXtZ1FdAi7MKLaXbQDTe-;w{fu1{yG#!2whc6K$`>3eR1xam6MZMT8a;llbMAD2(&bu$Di2`Z$S^v zfVW1#D$;C?v>w%OR&9EQ9R@;lqRCA5U3#K|crOu(W59)UtdlHYj=Dw`R!1|-t|?AbFE z3d;5&NE+F?4O)ddXy9E;riOj|@bCBOtne{t@VSkM?hwup60rvZN>WnN`03M$jLJ0r z$I#}RNYKr$f&Px5@aC45nsNwPQ(<9Yfy;2;-kuXYvFMl>qyag9_vUT`v%RE$r8D_8 zaty%y`(xD(0a+x_3)m??{J|W#m{(lN=^Dg$5%~zA0MlR5Pmfnh{`|_pLyTy@bDsvO zL18#{NYoh0#361V44tTkFg~fifbtgutOBt215F<5caZDkZiAc-4n(Md*%|jF3&0=} zkkc!$NB;pZm;%-arL!y)b) zPWtZMJ8p~f)dOhagW;n-5O3nu15Wi&zXU- z@rUv`>a6?E{zANpaQC;&Ks*SOyDRsuV*2kz6BX#Z|Lwh{*^l3Ag=(Fbq7UL&XD@BPA_;8Nz7;U_l`@27&?3?$u0V z>0myA$AjQ7ZDMlraNu$sg?bM#X_CW_*>1QFF2-M>ia z*1OLpT_txu*Ie}sX#h%N`oZ3iB7_Ph1(`@lEBFMrq(Bpelmhf~DyX^0^*+$~Kcxz? zD5!hZ)+})RUmzX_WYg8H3z53PaX*0~v{tPobx>jdm{;J^i+k%sly-E)VOxE2R5QSI zG4b#cz!&=9Fx7z6&IAuM|D>eb=X)i7_vs1q?&=ra-;2T@%xDeEEf5sw9R`1?W+(^_ zJfh((g&dZ-u4ui3PY>`j5bn74^TSOO#KDK^pPQQ-J=zNrSiFhzzYEL=40ddDCB5NW z(zxl06r?tgQVTwd6CD}EMjSy80!<!8N%i@H8rIb71Dt6ArPUeuHNGA!MX)$#IIvxq^N3Fq8?D`Ai;%d0w~RmC|y8a zq4fF#yB;_P<{<38P7Yy&?ClW&3iu`($|=WTglqkbvLcZA{jY)nJQ0NgB^VP6>jk_J zK+{&2fP7?li@U9fL<-g%Wbx3=O~>NZvx6x)AR_?pqk%j_fE)>bovgrs9Ry?+oYeob zxfB1hUibuOyoK)n8iGZ$5I<~%?utaip}R_IYu|!fg4Cq_jj1Tu)dF@#yL>;?^vui) zaJV4K=f9zX*=GZO6>Lp{4_XtVJe;2TobLWohUdD0A$_>P+^)j=;vNdFv}S82As8q!0OdbE zdjA?;Y!*q6*F0eOLvA2+gByz`U^BPpS#^mrY$oLBA&n9M#V9U59+}4h8U_D> z77mq$e#%4;qKpsIw^)un?2xE25Db&2Pp^T52mbLJvXNnu}3{d{3f+!WO zjSxe4WtQsV@G2a#BU~Ai+;eMOyFx#Bnw6cUU3=ARibKO!!~%nh&>*EslJA!+MYclK zr$1m6m&-=FmTU39FvkuBuC*E523JV~;*eShg6PigJ^pp&5*kz@k8oKhcszz((LawL zm@y2K*QKDZ$liOgoM!oR=va>b7C7+B3Y9CZ_}O_5MfYrcsq-cGqH^gJ>OyK`3l_qj zi{(oEzi-I=P##1~XGH3>(vNe_D{7$Dg=9cJiNg`pf(nNzO4tuzJ!mL|FVNFphM_vd z2Ja=P!X1>@`?$6jWyQJ7r7&|xnGcU_l>CSjzgp=Fw4 z?-s-IYgrMS+weMwumU{_4kfq71%Gg3ql9cE`Ge3~z06y*()Gxz+=Os0&>LbH%JF~m z<_6O69USnW2>wS4mz9yz<`U@AVp&rZ61v=DA6N@ zl<184uw1c!?;H|^+E9@v16lFg84AXfWujNg>0sx5Vk+*!oA>h*NH2h-39$=56{srA zL=co47+VRI*gK=mm)%qT)HfB5oIUvgEx0nEWN>K8`Zk8-tDODv8L$ZY&mZ|Fe97Y^ z)J%uhQzc&Bkr{~1DpB%X7&r)GCL5GkyEkPZZ2)_d@fKET`E2`*s4QJ?2rts3m~T}2 z5LEtBDQH=VQeQz=MXnFN1$gk%YaMS^)u+fZ*flOHuAaSzL&W!^5vBG!2mNX)`Xr1}m~sRbM|Gf)$*9W1|NeA^ZS(=^UU;Wc~n<5jHq4 zP|p$lq_i|Trz4bcWgzcHcXzij7&cIOVBjVb`X`bOgO6ptoC%wR0YoeWSaiOyDJ_?c zFg0>6GWP$|2l>ZTaX^-~aZO`FmQJ!Fv`bDz9NH#C&mX>JS5>eL`e(>jt)PA-}SrOn2lI+B6r`0f177XWgeX_n_)2tu``Vyk*gTJFw zYK%5knfbLKIYjmcfw;BHS*DgtMzM#%xGFV1^JyXottb>|M1fR#zvWYHH8hA|KJsJS zusVg+Ny%_i)4d|UNNHc#eCa)pjU9&n&+*lP>`nj(tK*;X{-Z#G)&j8oQkl-R|I1A4 zL<6Y3|2?rgWJ&FaLfw3g{9J(lp5EnnjgCTeL@*Uh|FwF)=4%#2Cnd2|Dj6D5f~OGf zKIwVYWwHA!L}$1D^wHHACP1rd25}fBagh)`)Euy9JH#kSNxQ_5&qFTarf4gYZM##V z-C1WjHgSA!Afakl!tQ*Pnl!TW{*RvdZ7_H+{FRb;1EJ_4wFKOtCe(npUgL8yvv3}W zCjdulLvcg35Cl!kJ_QKepf`CZ(!pxVK|b*M(tt6TD#k8c0Wds;HQs z1y|)Z4A6oymUqT5jxvVXo`?onU{q>bxMYIiU#N(QVKNjM+fi#j*xfa!jgdF;V{*;t zIL@(0#P_r7^>uZJxWN@X(8*?k^`~rL$_86xX6j7|s4k1Et1#zrFzJn$|J-dPCxA|B z)&L1SR90SEA08Tl@gz2O_WAvut}fGX9whgH;eQ62LU@_>gK9 zz3uHekDZ8!iJe11kjWu23m+UBDhDNk5A_t>!9w_11`#AYFVT!aK{1CD3M0Qj!h2BM zV8(+KqqVW3uCL$C5DfOxd{;|zb2U9DOnyx+z8L5P^kW`-3W@103Sw(IxFgL} z@j@*8aXv}O$>qTOoyP|U2DrhkvV@TqWD@r+JvhvuEDjy$Gi*yI_iVN z6^g->*D!UDhFV=)%l^>^at1vU)7&p*37P-v?Bi4t@ zL-Nd?0Q-euafGi!h?h0nUevJKUsb&{5z0-&Y~t& zgj*s}tW<<2U0;!XoMV?pA?=;pW6}N*N`A|~Q8N_s6Ed3m`UoJKFtgtH2sX3nV)3N> z^XH{l+J6TsQ2>&9lBrfAgeFlaZpZ0+NEL@m-3kL6 
zVDO?tkfves`)d^F<|=h3>tB#rEcheG-9M_xD!@G1J=jOk=!dN?1nshv8V!y9e_Oxp z-q?HT8pt>x6amn{>QrM|fgd2tiJSkU>-)P^J}1a{tMR+Vh?me24jl|LBJ%R$@^Y7D zZ)%THCv};5Cm}MusqpaOZL|Iz+xq7?Fu=H9`lM>%w3?jsSIHap{I+lsy&Hb%=N|-mrz})SDeAh|5$VKTE{75q=~D*D#mpmAHdU}v@R8_~u91N4IWkf`S_S|S482R>LzI@6=er{s0bqE9jLlcHH??vlvAv@527NB`?UG~82w)e zhqN#XB)2~!Wpn^?a%V$s!tlG*S{Ie{C-}J>An&{=et%;UL8cX#8H}V2;5SvTF!~_E zpEcI;jO@To_}#Vth#mZvsa*YTs%O7M`E?^VL0eUICmo(i#m_hg*a>5Mj%#BA|6t{; zbFc~fEC7!pKe2~d<6mTZ9aZhNmOL<)8^AoivWJIqeN1ihz;4KmK8psZ#`l$k(7z6H z8WFWD_(~u4eUWg=G)SJHYu4D=0G9d^O*Yx_u@ML{^ z#p8M>)P-A|nibB#gP3cyWF|ZU13h-vB%REwd!82DmB-{^<@*qd6GqBK7TZmH)1i~l zqOCg5>{{Km8%f{g2N*~wZjE;eE_1%5NclzZ@WspQdtC3Ps7{X++AS=0v3r^RQq}qP zo=|qpta+xnq#jg!cG(ai#|7MZT-Xo*KVZW(j^$Y9G;oXx>*jr`Bf?N=VQ6lT-U)x@ zh+ShD{x#2(Ft>9A2Wd;;_q#_-7+0_U9tx%Q~*T*;-|R~SK6O{Zr70!USJqN&wpiN zf<940!n3YS3iue$lP9ZcL}Dv%`Q&zJ9aTxo7OCsYUrtOIdU`@hnqB6-iT{~{g9F7) zJoX^$A6mqHwC_hd914202_JB$Ew0QG$!&I8$rdj)sNDX3U7i9H_UjJ-xhyqPs!tuc zYwHsdlKPWGjCAu8aQ_uwWUpVpP9c;Nsr}&bpHtjg+_iVqIJIkpat8<7hy8v@qAHQE zezKq{D5zHSHJ0Uc;J|^1BIQ4B;xVpp6*}0w2%CKsnxN(?DVuFCw z-Zl;R?7CyE%4vG-!(gP3LxR8RIH^f5@9XLDO)P0^v!XA-E`=@@^{to@6q{Btuw^=~ z(Iy(3WG|Q7ilG4wGIrBDwrW+MAXaD*65z*?U1+7hbo{O&-{iBa6jx8vqV$6&shZEV zRr>?82LCgX5z7o)Ib&yQJKXg?#XEpD6P#vOFqaQCLE+IVkfHP_k6j?){kH992X7Lz zVm$pgwXz-3L&WrWo;r2$rE6{R)0G&AA|O%|MY+9bl}DGpnIIM z>xM9qkdiW5eoRSB9s~(8n|d_HNa-l*)Wj$&`3wuNL}gFY`*;j6-1ysH~58aQ`CS?=-xfKm|pXwM-e z7tX(T?>X9L`J~1Yubz7+o2wO^k$hB3LgM2*9^e2HN5d+wOXD^)P$VY(u&nHD=Gh6ObLaMw zu~cLP8NsBYrGY;)4sZM5pG(I}0K z&_Jz8tiR|=-?AZ1>O`Z4rB$A|U4bN*%a!wpZjJ8rG;wy$YH$B_oqUB@f(mEf$Bti# zaLf^Pb#(**JmV62Evk|j!V4nUyVcdN-uP{eWy_BRckdbjW4UUk2(?g2@T3zKp=(3D zEJifPb-EfF!p}@>KajGRH$m!~46;Pv4`ABs|A+D60cd!#lEGiG6TG2(vV||4W;~!z zB`hDj8ZjAyPGQS|46y$pFiF)+?cW$a!qBkI$jBfY=SUHQ_g%C0g?6gR7Ih5`yf(eb zI!sw7C1Xr46v-;tPp7UeT?)~6vbN%T3}O|}IV6M4CuCmZxpe{Zq>$)YAK&AM1E3zm zI(HnFZcb26=elZcJYvp_RVCV%^oRG=4y?TTLTXqbwZJT5px5BvzL!j&$+uPWEOP)$ zakH#U5)GY!o{D@_OMNjb{`LR<`YWLM+-kns^qoI?h)YWTJnn9hW;s&kFB<&qa988w zvvH&FNEqlvzj%wPHI(f=)>|0qsF_!l8kG<}qdH@{PX2WeGMF$kd`rGLd{`WP25GDV z(snH?N#_Wfg#M>t`%qK!`ODXF(#NUPlBu&WxkVArpwZo$(#(WVDjjNe>Km}g^cbPi4kkhe_alSXJv8bh@#Pj>=7mVGZKs0)oaD5-78XZ z^_0x>pHGrPb(r1&G`#CQkYR<+!cw{G>c$=Q6l@S4>3t_lSsiS4?EHL?J|yy_+5r+E zhLnq9ZNy0Ie~(N4Ek4mWWH|okTWL8tGZauwr^losn=$=zQw5`P7#jcklY@_Kn=LW2 zO#z+l>fk9D?xri*ca|v1*G)n%#h26wL|TCCkizgp++RYuGQxfOJQHwN0rCLs>@lGl zeBILuwQ9FZWxt=T)2r+i!YnLu$=otBlj;<^+(<3i6*`^6^XO3)#)_ZPHroa0Zr)jx z6AG&eJRN+1ezsHG$fOGELx+YswKbzehkh_<4$_A>vVBV?)wmtfm2ep57MCA)%SHYM z88NU?{4ZaPlgLHm2#f&TJw0d24!c*d6O)tIWLQPRaODcBg^cKcX0;}0!_e7^v)r35 z=3@6Bzh1nDSQN7yBU{_<`z&*J-W5Wn#RA{HDtu021 zpbobEI=(&FM-^xbAzJ8g# z7VJfpN{7)Z)l!)dm-a&0Dz2HAabirHMny*t&imfoEB4?tw=FR8lI;GH>qkb%m<@6- zz$OY;b0=E&B$}#X^&qAVnUPiFZq{d&+tDS8V*hwz!o}tzeZt0kE>8YWuuIq=((Z(4l@znwUY+c4?PMGI{LrId6@$ga zDmE;%a|;NN0JbvwB>WIg++)a8srY!~LOYmaT1H26 z=Bd>8v!4PZMV)6^{!jSk^i{+-2r^s>*LS-zpeTQ9n>DV3@zNuMBq*XpwIvISjmTey n%)7s5piV3P%K#BBAGNSC@Ov2Qm-k^a1??%ZFO|HIv`+jVI|SAq literal 0 HcmV?d00001 diff --git a/tf2.0/edgeml/__init__.py b/tf2.0/edgeml/__init__.py new file mode 100644 index 000000000..8ac062499 --- /dev/null +++ b/tf2.0/edgeml/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +''' +package edgeml + +Provides: Bonsai, ProtoNN and BasicTrainer routines + for both +''' + +# TODO Override the __all__ variable for the package +# and limit the functions that are exposed. 
diff --git a/tf2.0/edgeml/graph/__init__.py b/tf2.0/edgeml/graph/__init__.py
new file mode 100644
index 000000000..3d7ff8299
--- /dev/null
+++ b/tf2.0/edgeml/graph/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
diff --git a/tf2.0/edgeml/graph/bonsai.py b/tf2.0/edgeml/graph/bonsai.py
new file mode 100644
index 000000000..10851a1fd
--- /dev/null
+++ b/tf2.0/edgeml/graph/bonsai.py
@@ -0,0 +1,180 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import tensorflow as tf
+import numpy as np
+import warnings
+
+
+class Bonsai:
+    def __init__(self, numClasses, dataDimension, projectionDimension,
+                 treeDepth, sigma,
+                 isRegression=False, W=None, T=None, V=None, Z=None):
+        '''
+        Expected Dimensions:
+
+        Bonsai Params // Optional
+        W [numClasses*totalNodes, projectionDimension]
+        V [numClasses*totalNodes, projectionDimension]
+        Z [projectionDimension, dataDimension]
+        T [internalNodes, projectionDimension]
+
+        internalNodes = 2**treeDepth - 1
+        totalNodes = 2*internalNodes + 1
+
+        sigma - tanh non-linearity
+        sigmaI - indicator function for node probabilities;
+        sigmaI has to be set to infinity (1e9 in practice)
+        while doing testing/inference
+        numClasses will be reset to 1 in the binary case
+        '''
+        self.dataDimension = dataDimension
+        self.projectionDimension = projectionDimension
+        self.isRegression = isRegression
+
+        if self.isRegression and (numClasses != 1):
+            warnings.warn("Number of classes cannot be greater than 1 for regression")
+            self.numClasses = 1
+        elif numClasses == 2:
+            self.numClasses = 1
+        else:
+            self.numClasses = numClasses
+
+        self.treeDepth = treeDepth
+        self.sigma = sigma
+
+        self.internalNodes = 2**self.treeDepth - 1
+        self.totalNodes = 2 * self.internalNodes + 1
+
+        self.W = self.initW(W)
+        self.V = self.initV(V)
+        self.T = self.initT(T)
+        self.Z = self.initZ(Z)
+
+        self.assertInit()
+
+        self.score = None
+        self.X_ = None
+        self.prediction = None
+
+    def initZ(self, Z):
+        if Z is None:
+            Z = tf.random.normal(
+                [self.projectionDimension, self.dataDimension])
+        Z = tf.Variable(Z, name='Z', dtype=tf.float32)
+        return Z
+
+    def initW(self, W):
+        if W is None:
+            W = tf.random.normal(
+                [self.numClasses * self.totalNodes, self.projectionDimension])
+        W = tf.Variable(W, name='W', dtype=tf.float32)
+        return W
+
+    def initV(self, V):
+        if V is None:
+            V = tf.random.normal(
+                [self.numClasses * self.totalNodes, self.projectionDimension])
+        V = tf.Variable(V, name='V', dtype=tf.float32)
+        return V
+
+    def initT(self, T):
+        if T is None:
+            T = tf.random.normal(
+                [self.internalNodes, self.projectionDimension])
+        T = tf.Variable(T, name='T', dtype=tf.float32)
+        return T
+
+    def __call__(self, X, sigmaI):
+        '''
+        Function to build the Bonsai Tree graph
+        Expected Dimensions
+
+        X is [_, self.dataDimension]
+        '''
+        errmsg = "Dimension Mismatch, X is [_, self.dataDimension]"
+        assert (len(X.shape) == 2 and int(
+            X.shape[1]) == self.dataDimension), errmsg
+        if self.score is not None:
+            return self.score, self.X_
+
+        X_ = tf.divide(tf.matmul(self.Z, X, transpose_b=True),
+                       self.projectionDimension)
+
+        W_ = self.W[0:(self.numClasses)]
+        V_ = self.V[0:(self.numClasses)]
+
+        self.__nodeProb = []
+        self.__nodeProb.append(1)
+
+        score_ = self.__nodeProb[0] * tf.multiply(
+            tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))
+        for i in range(1, self.totalNodes):
+            W_ = self.W[i * self.numClasses:((i + 1) * self.numClasses)]
+            V_ = self.V[i * self.numClasses:((i + 1) * self.numClasses)]
+
+            T_ = tf.reshape(self.T[int(np.ceil(i / 2.0) - 1.0)],
+                            [-1, self.projectionDimension])
+            prob = (1 + ((-1)**(i + 1)) *
+                    tf.tanh(tf.multiply(sigmaI, tf.matmul(T_, X_))))
+
+            prob = tf.divide(prob, 2.0)
+            prob = self.__nodeProb[int(np.ceil(i / 2.0) - 1.0)] * prob
+            self.__nodeProb.append(prob)
+            score_ += self.__nodeProb[i] * tf.multiply(
+                tf.matmul(W_, X_), tf.tanh(self.sigma * tf.matmul(V_, X_)))
+
+        self.score = score_
+        self.X_ = X_
+        return self.score, self.X_
+
+    def getPrediction(self):
+        '''
+        Takes in a score tensor and outputs an integer class for each data point
+        '''
+
+        # Classification.
+        if (self.isRegression == False):
+            if self.prediction is not None:
+                return self.prediction
+
+            if self.numClasses > 2:
+                self.prediction = tf.argmax(input=tf.transpose(a=self.score), axis=1)
+            else:
+                self.prediction = tf.argmax(
+                    input=tf.concat([tf.transpose(a=self.score),
+                                     0 * tf.transpose(a=self.score)], 1), axis=1)
+        # Regression.
+        elif (self.isRegression == True):
+            # For regression, the scores are the actual predictions; just return them.
+            self.prediction = self.score
+
+        return self.prediction
+
+    def assertInit(self):
+        errmsg = "Number of Classes for regression can only be 1."
+        if (self.isRegression == True):
+            assert (self.numClasses == 1), errmsg
+        errRank = "All parameters must have exactly two dimensions, shape = [a, b]"
+        assert len(self.W.shape) == len(self.Z.shape), errRank
+        assert len(self.W.shape) == len(self.T.shape), errRank
+        assert len(self.W.shape) == 2, errRank
+        msg = "W and V should be of same Dimensions"
+        assert self.W.shape == self.V.shape, msg
+        errW = "W and V are [numClasses*totalNodes, projectionDimension]"
+        assert self.W.shape[0] == self.numClasses * self.totalNodes, errW
+        assert self.W.shape[1] == self.projectionDimension, errW
+        errZ = "Z is [projectionDimension, dataDimension]"
+        assert self.Z.shape[0] == self.projectionDimension, errZ
+        assert self.Z.shape[1] == self.dataDimension, errZ
+        errT = "T is [internalNodes, projectionDimension]"
+        assert self.T.shape[0] == self.internalNodes, errT
+        assert self.T.shape[1] == self.projectionDimension, errT
+        assert int(self.numClasses) > 0, "numClasses should be > 0"
+        msg = "# of features in data should be > 0"
+        assert int(self.dataDimension) > 0, msg
+        msg = "Projection should be > 0 dims"
+        assert int(self.projectionDimension) > 0, msg
+        msg = "treeDepth should be >= 0"
+        assert int(self.treeDepth) >= 0, msg
diff --git a/tf2.0/edgeml/graph/protoNN.py b/tf2.0/edgeml/graph/protoNN.py
new file mode 100644
index 000000000..2ea5b85ff
--- /dev/null
+++ b/tf2.0/edgeml/graph/protoNN.py
@@ -0,0 +1,191 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+import numpy as np
+import tensorflow as tf
+
+
+class ProtoNN:
+    def __init__(self, inputDimension, projectionDimension, numPrototypes,
+                 numOutputLabels, gamma,
+                 W=None, B=None, Z=None):
+        '''
+        Forward computation graph for ProtoNN.
+
+        inputDimension: Input data dimension or feature dimension.
+        projectionDimension: Dimension of the projected space (hyperparameter).
+        numPrototypes: Number of prototypes (hyperparameter).
+        numOutputLabels: The number of output labels or classes.
+        W, B, Z: Numpy matrices that can be used to initialize the
+            projection matrix (W), the prototype matrix (B) and the
+            prototype labels matrix (Z).
+
+        Expected Dimensions:
+            W   inputDimension (d) x projectionDimension (d_cap)
+            B   projectionDimension (d_cap) x numPrototypes (m)
+            Z   numOutputLabels (L) x numPrototypes (m)
+        '''
+        with tf.compat.v1.name_scope('protoNN') as ns:
+            self.__nscope = ns
+            self.__d = inputDimension
+            self.__d_cap = projectionDimension
+            self.__m = numPrototypes
+            self.__L = numOutputLabels
+
+            self.__inW = W
+            self.__inB = B
+            self.__inZ = Z
+            self.__inGamma = gamma
+            self.W, self.B, self.Z = None, None, None
+            self.gamma = None
+
+            self.__validInit = False
+            self.__initWBZ()
+            self.__initGamma()
+            self.__validateInit()
+            self.protoNNOut = None
+            self.predictions = None
+            self.accuracy = None
+
+    def __validateInit(self):
+        self.__validInit = False
+        errmsg = "Dimensions mismatch! Should be W[d, d_cap]"
+        errmsg += ", B[d_cap, m] and Z[L, m]"
+        d, d_cap, m, L, _ = self.getHyperParams()
+        assert self.W.shape[0] == d, errmsg
+        assert self.W.shape[1] == d_cap, errmsg
+        assert self.B.shape[0] == d_cap, errmsg
+        assert self.B.shape[1] == m, errmsg
+        assert self.Z.shape[0] == L, errmsg
+        assert self.Z.shape[1] == m, errmsg
+        self.__validInit = True
+
+    def __initWBZ(self):
+        with tf.compat.v1.name_scope(self.__nscope):
+            W = self.__inW
+            if W is None:
+                W = tf.compat.v1.initializers.random_normal()
+                W = W([self.__d, self.__d_cap])
+            self.W = tf.Variable(W, name='W', dtype=tf.float32)
+
+            B = self.__inB
+            if B is None:
+                B = tf.compat.v1.initializers.random_uniform()
+                B = B([self.__d_cap, self.__m])
+            self.B = tf.Variable(B, name='B', dtype=tf.float32)
+
+            Z = self.__inZ
+            if Z is None:
+                Z = tf.compat.v1.initializers.random_normal()
+                Z = Z([self.__L, self.__m])
+            Z = tf.Variable(Z, name='Z', dtype=tf.float32)
+            self.Z = Z
+        return self.W, self.B, self.Z
+
+    def __initGamma(self):
+        with tf.compat.v1.name_scope(self.__nscope):
+            gamma = self.__inGamma
+            self.gamma = tf.constant(gamma, name='gamma')
+
+    def getHyperParams(self):
+        '''
+        Returns the model hyperparameters:
+            [inputDimension, projectionDimension,
+             numPrototypes, numOutputLabels, gamma]
+        '''
+        d = self.__d
+        dcap = self.__d_cap
+        m = self.__m
+        L = self.__L
+        return d, dcap, m, L, self.gamma
+
+    def getModelMatrices(self):
+        '''
+        Returns Tensorflow tensors of the model matrices, which
+        can then be evaluated to obtain the corresponding numpy arrays.
+
+        These can then be exported as part of other implementations of
+        ProtoNN, for instance a C++ implementation or a pure Python
+        implementation.
+        Returns
+            [ProjectionMatrix (W), prototypeMatrix (B),
+             prototypeLabelsMatrix (Z), gamma]
+        '''
+        return self.W, self.B, self.Z, self.gamma
+
+    def __call__(self, X, Y=None):
+        '''
+        This method is responsible for construction of the forward computation
+        graph. The end point of the computation graph, or in other words the
+        output operator for the forward computation, is returned. Additionally,
+        if the argument Y is provided, a classification accuracy operator with
+        Y as target will also be created. For this, Y is assumed to be in
+        one-hot encoded format and the class with the maximum prediction score
+        is compared to the encoded class in Y. This accuracy operator is
+        returned by the getAccuracyOp() method. If a different accuracyOp is
+        required, it can be defined by overriding the
+        createAccOp(protoNNScoresOut, Y) method.
+
+        X: Input tensor or placeholder of shape [-1, inputDimension]
+        Y: Optional tensor or placeholder for targets (labels or classes).
+            Expected shape is [-1, numOutputLabels].
+        returns: The forward computation outputs, self.protoNNOut
+        '''
+        # This assertion should never fire; __init__ validates the shapes.
+        assert self.__validInit is True, "Initialization failed!"
+        if self.protoNNOut is not None:
+            return self.protoNNOut
+
+        W, B, Z, gamma = self.W, self.B, self.Z, self.gamma
+        with tf.compat.v1.name_scope(self.__nscope):
+            WX = tf.matmul(X, W)
+            # Reshape WX so that broadcasting against B can work
+            dim = [-1, WX.shape.as_list()[1], 1]
+            WX = tf.reshape(WX, dim)
+            dim = [1, B.shape.as_list()[0], -1]
+            B_ = tf.reshape(B, dim)
+            l2sim = B_ - WX
+            l2sim = tf.pow(l2sim, 2)
+            l2sim = tf.reduce_sum(input_tensor=l2sim, axis=1, keepdims=True)
+            self.l2sim = l2sim
+            gammal2sim = (-1 * gamma * gamma) * l2sim
+            M = tf.exp(gammal2sim)
+            dim = [1] + Z.shape.as_list()
+            Z_ = tf.reshape(Z, dim)
+            y = tf.multiply(Z_, M)
+            y = tf.reduce_sum(input_tensor=y, axis=2, name='protoNNScoreOut')
+            self.protoNNOut = y
+            self.predictions = tf.argmax(input=y, axis=1, name='protoNNPredictions')
+            if Y is not None:
+                self.createAccOp(self.protoNNOut, Y)
+        return y
+
+    def createAccOp(self, outputs, target):
+        '''
+        Define an accuracy operation on ProtoNN's output scores and targets.
+        Here a simple classification accuracy operator is defined. More
+        complicated operators (for multiple-label problems and so forth) can
+        be defined by overriding this method.
+        '''
+        assert self.predictions is not None
+        target = tf.argmax(input=target, axis=1)
+        correctPrediction = tf.equal(self.predictions, target)
+        acc = tf.reduce_mean(input_tensor=tf.cast(correctPrediction, tf.float32),
+                             name='protoNNAccuracy')
+        self.accuracy = acc
+
+    def getPredictionsOp(self):
+        '''
+        The predictions operator is defined as argmax(protoNNScores) for each
+        prediction.
+        '''
+        return self.predictions
+
+    def getAccuracyOp(self):
+        '''
+        Returns the accuracyOp as defined by createAccOp. It defaults to
+        multi-class classification accuracy.
+        '''
+        msg = "Accuracy operator not defined in graph. Did you provide Y as an"
+        msg += " argument to __call__?"
+        assert self.accuracy is not None, msg
+        return self.accuracy
diff --git a/tf2.0/edgeml/trainer/__init__.py b/tf2.0/edgeml/trainer/__init__.py
new file mode 100644
index 000000000..3d7ff8299
--- /dev/null
+++ b/tf2.0/edgeml/trainer/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
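Before moving on to the trainers, a minimal construction sketch for the ProtoNN graph above. All of the sizes and the gamma value here are illustrative only (not taken from this patch), and eager execution is assumed to be disabled as noted earlier:

```python
import tensorflow as tf
from edgeml.graph.protoNN import ProtoNN

tf.compat.v1.disable_eager_execution()

# Illustrative sizes: 784-dim inputs, 10 labels, 20-dim projection, 60 prototypes.
X = tf.compat.v1.placeholder(tf.float32, [None, 784], name='X')
Y = tf.compat.v1.placeholder(tf.float32, [None, 10], name='Y')

protoNN = ProtoNN(784, 20, 60, 10, 0.0015)
scores = protoNN(X, Y)                # builds the forward graph once
accuracyOp = protoNN.getAccuracyOp()  # available because Y was supplied
W, B, Z, gamma = protoNN.getModelMatrices()
```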
diff --git a/tf2.0/edgeml/trainer/bonsaiTrainer.py b/tf2.0/edgeml/trainer/bonsaiTrainer.py
new file mode 100644
index 000000000..2e86663ae
--- /dev/null
+++ b/tf2.0/edgeml/trainer/bonsaiTrainer.py
@@ -0,0 +1,560 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from __future__ import print_function
+import tensorflow as tf
+import edgeml.utils as utils
+import numpy as np
+import os
+import sys
+
+
+class BonsaiTrainer:
+
+    def __init__(self, bonsaiObj, lW, lT, lV, lZ, sW, sT, sV, sZ,
+                 learningRate, X, Y, useMCHLoss=False, outFile=None, regLoss='huber'):
+        '''
+        bonsaiObj - Initialised Bonsai Object and Graph
+        lW, lT, lV and lZ are regularisers for the Bonsai Params
+        sW, sT, sV and sZ are sparsity factors for the Bonsai Params
+        learningRate - learning rate for the optimizer
+        X is the Data Placeholder - Dims [_, dataDimension]
+        Y - Label placeholder for loss computation
+        useMCHLoss - For the choice between Hinge Loss and Cross Entropy Loss
+        useMCHLoss - True - MultiClass - multiClassHingeLoss
+        useMCHLoss - False - MultiClass - crossEntropyLoss
+        '''
+
+        self.bonsaiObj = bonsaiObj
+        self.regressionLoss = regLoss
+
+        self.lW = lW
+        self.lV = lV
+        self.lT = lT
+        self.lZ = lZ
+
+        self.sW = sW
+        self.sV = sV
+        self.sT = sT
+        self.sZ = sZ
+
+        self.Y = Y
+        self.X = X
+
+        self.useMCHLoss = useMCHLoss
+
+        if outFile is not None:
+            print("Outfile : ", outFile)
+            self.outFile = open(outFile, 'w')
+        else:
+            self.outFile = sys.stdout
+
+        self.learningRate = learningRate
+
+        self.assertInit()
+
+        self.sigmaI = tf.compat.v1.placeholder(tf.float32, name='sigmaI')
+
+        self.score, self.X_ = self.bonsaiObj(self.X, self.sigmaI)
+
+        self.loss, self.marginLoss, self.regLoss = self.lossGraph()
+
+        self.trainStep = self.trainGraph()
+        '''
+        self.accuracy -> 'MAE' for Regression.
+        self.accuracy -> 'Accuracy' for Classification.
+        '''
+        self.accuracy = self.accuracyGraph()
+        self.prediction = self.bonsaiObj.getPrediction()
+
+        if self.sW > 0.99 and self.sV > 0.99 and self.sZ > 0.99 and self.sT > 0.99:
+            self.isDenseTraining = True
+        else:
+            self.isDenseTraining = False
+
+        self.hardThrsd()
+        self.sparseTraining()
+
+    def lossGraph(self):
+        '''
+        Loss Graph for the given Bonsai Obj
+        '''
+        self.regLoss = 0.5 * (self.lZ * tf.square(tf.norm(tensor=self.bonsaiObj.Z)) +
+                              self.lW * tf.square(tf.norm(tensor=self.bonsaiObj.W)) +
+                              self.lV * tf.square(tf.norm(tensor=self.bonsaiObj.V)) +
+                              self.lT * tf.square(tf.norm(tensor=self.bonsaiObj.T)))
+
+        # Loss functions for classification.
+        if (self.bonsaiObj.isRegression is False):
+            if (self.bonsaiObj.numClasses > 2):
+                if self.useMCHLoss is True:
+                    self.batch_th = tf.compat.v1.placeholder(tf.int64, name='batch_th')
+                    self.marginLoss = utils.multiClassHingeLoss(
+                        tf.transpose(a=self.score), self.Y,
+                        self.batch_th)
+                else:
+                    self.marginLoss = utils.crossEntropyLoss(
+                        tf.transpose(a=self.score), self.Y)
+                self.loss = self.marginLoss + self.regLoss
+            else:
+                self.marginLoss = tf.reduce_mean(input_tensor=tf.nn.relu(
+                    1.0 - (2 * self.Y - 1) * tf.transpose(a=self.score)))
+                self.loss = self.marginLoss + self.regLoss
+
+        # Loss functions for regression.
+        elif (self.bonsaiObj.isRegression is True):
+            if (self.regressionLoss == 'huber'):
+                # Use Huber loss, because it is more robust to outliers.
+                self.marginLoss = tf.compat.v1.losses.huber_loss(
+                    self.Y, tf.transpose(a=self.score))
+                self.loss = self.marginLoss + self.regLoss
+            elif (self.regressionLoss == 'l2'):
+                # L2 loss function.
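+                # (Note: tf.nn.l2_loss(t) computes sum(t ** 2) / 2, an
+                # unnormalised squared-error penalty, so this branch is more
+                # sensitive to outlier targets than the Huber branch above.)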
+ self.marginLoss = tf.nn.l2_loss( + self.Y - tf.transpose(a=self.score)) + self.loss = self.marginLoss + self.regLoss + + return self.loss, self.marginLoss, self.regLoss + + def trainGraph(self): + ''' + Train Graph for the loss generated by Bonsai + ''' + self.bonsaiObj.TrainStep = tf.compat.v1.train.AdamOptimizer( + self.learningRate).minimize(self.loss) + + return self.bonsaiObj.TrainStep + + def accuracyGraph(self): + ''' + Accuracy Graph to evaluate accuracy when needed + ''' + if(self.bonsaiObj.isRegression is False): + if (self.bonsaiObj.numClasses > 2): + correctPrediction = tf.equal( + tf.argmax(input=tf.transpose(a=self.score), axis=1), tf.argmax(input=self.Y, axis=1)) + self.accuracy = tf.reduce_mean( + input_tensor=tf.cast(correctPrediction, tf.float32)) + else: + y_ = self.Y * 2 - 1 + correctPrediction = tf.multiply(tf.transpose(a=self.score), y_) + correctPrediction = tf.nn.relu(correctPrediction) + correctPrediction = tf.math.ceil(tf.tanh(correctPrediction)) + self.accuracy = tf.reduce_mean( + input_tensor=tf.cast(correctPrediction, tf.float32)) + + elif (self.bonsaiObj.isRegression is True): + # Accuracy for regression , in terms of mean absolute error. + self.accuracy = utils.mean_absolute_error(tf.reshape( + self.score, [-1, 1]), tf.reshape(self.Y, [-1, 1])) + return self.accuracy + + def hardThrsd(self): + ''' + Set up for hard Thresholding Functionality + ''' + self.__Wth = tf.compat.v1.placeholder(tf.float32, name='Wth') + self.__Vth = tf.compat.v1.placeholder(tf.float32, name='Vth') + self.__Zth = tf.compat.v1.placeholder(tf.float32, name='Zth') + self.__Tth = tf.compat.v1.placeholder(tf.float32, name='Tth') + + self.__Woph = self.bonsaiObj.W.assign(self.__Wth) + self.__Voph = self.bonsaiObj.V.assign(self.__Vth) + self.__Toph = self.bonsaiObj.T.assign(self.__Tth) + self.__Zoph = self.bonsaiObj.Z.assign(self.__Zth) + + self.hardThresholdGroup = tf.group( + self.__Woph, self.__Voph, self.__Toph, self.__Zoph) + + def sparseTraining(self): + ''' + Set up for Sparse Retraining Functionality + ''' + self.__Wops = self.bonsaiObj.W.assign(self.__Wth) + self.__Vops = self.bonsaiObj.V.assign(self.__Vth) + self.__Zops = self.bonsaiObj.Z.assign(self.__Zth) + self.__Tops = self.bonsaiObj.T.assign(self.__Tth) + + self.sparseRetrainGroup = tf.group( + self.__Wops, self.__Vops, self.__Tops, self.__Zops) + + def runHardThrsd(self, sess): + ''' + Function to run the IHT routine on Bonsai Obj + ''' + currW = self.bonsaiObj.W.eval() + currV = self.bonsaiObj.V.eval() + currZ = self.bonsaiObj.Z.eval() + currT = self.bonsaiObj.T.eval() + + self.__thrsdW = utils.hardThreshold(currW, self.sW) + self.__thrsdV = utils.hardThreshold(currV, self.sV) + self.__thrsdZ = utils.hardThreshold(currZ, self.sZ) + self.__thrsdT = utils.hardThreshold(currT, self.sT) + + fd_thrsd = {self.__Wth: self.__thrsdW, self.__Vth: self.__thrsdV, + self.__Zth: self.__thrsdZ, self.__Tth: self.__thrsdT} + sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd) + + def runSparseTraining(self, sess): + ''' + Function to run the Sparse Retraining routine on Bonsai Obj + ''' + currW = self.bonsaiObj.W.eval() + currV = self.bonsaiObj.V.eval() + currZ = self.bonsaiObj.Z.eval() + currT = self.bonsaiObj.T.eval() + + newW = utils.copySupport(self.__thrsdW, currW) + newV = utils.copySupport(self.__thrsdV, currV) + newZ = utils.copySupport(self.__thrsdZ, currZ) + newT = utils.copySupport(self.__thrsdT, currT) + + fd_st = {self.__Wth: newW, self.__Vth: newV, + self.__Zth: newZ, self.__Tth: newT} + sess.run(self.sparseRetrainGroup, 
+                 feed_dict=fd_st)
+
+    def assertInit(self):
+        err = "sparsity must be between 0 and 1"
+        assert self.sW >= 0 and self.sW <= 1, "W " + err
+        assert self.sV >= 0 and self.sV <= 1, "V " + err
+        assert self.sZ >= 0 and self.sZ <= 1, "Z " + err
+        assert self.sT >= 0 and self.sT <= 1, "T " + err
+        errMsg = "Dimension Mismatch, Y has to be [_, " + \
+            str(self.bonsaiObj.numClasses) + "]"
+        errCont = " numClasses is 1 in the binary case by design"
+        assert (len(self.Y.shape) == 2 and
+                self.Y.shape[1] == self.bonsaiObj.numClasses), errMsg + errCont
+
+    def saveParams(self, currDir):
+        '''
+        Function to save the parameter matrices into a given folder
+        '''
+        paramDir = currDir + '/'
+        np.save(paramDir + "W.npy", self.bonsaiObj.W.eval())
+        np.save(paramDir + "V.npy", self.bonsaiObj.V.eval())
+        np.save(paramDir + "T.npy", self.bonsaiObj.T.eval())
+        np.save(paramDir + "Z.npy", self.bonsaiObj.Z.eval())
+        hyperParamDict = {'dataDim': self.bonsaiObj.dataDimension,
+                          'projDim': self.bonsaiObj.projectionDimension,
+                          'numClasses': self.bonsaiObj.numClasses,
+                          'depth': self.bonsaiObj.treeDepth,
+                          'sigma': self.bonsaiObj.sigma}
+        hyperParamFile = paramDir + 'hyperParam.npy'
+        np.save(hyperParamFile, hyperParamDict)
+
+    def saveParamsForSeeDot(self, currDir):
+        '''
+        Function to save the parameter matrices into a given folder for the SeeDot compiler
+        '''
+        seeDotDir = currDir + '/SeeDot/'
+
+        if os.path.isdir(seeDotDir) is False:
+            try:
+                os.mkdir(seeDotDir)
+            except OSError:
+                print("Creation of the directory %s failed" %
+                      seeDotDir)
+
+        np.savetxt(seeDotDir + "W",
+                   utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.W.eval(),
+                                                      self.bonsaiObj.numClasses,
+                                                      self.bonsaiObj.totalNodes),
+                   delimiter="\t")
+        np.savetxt(seeDotDir + "V",
+                   utils.restructreMatrixBonsaiSeeDot(self.bonsaiObj.V.eval(),
+                                                      self.bonsaiObj.numClasses,
+                                                      self.bonsaiObj.totalNodes),
+                   delimiter="\t")
+        np.savetxt(seeDotDir + "T", self.bonsaiObj.T.eval(), delimiter="\t")
+        np.savetxt(seeDotDir + "Z", self.bonsaiObj.Z.eval(), delimiter="\t")
+        np.savetxt(seeDotDir + "Sigma",
+                   np.array([self.bonsaiObj.sigma]), delimiter="\t")
+
+    def loadModel(self, currDir):
+        '''
+        Loads a saved model; the returned dicts can be passed back into the
+        constructor to rebuild the model.
+        Returns two dicts, one for params and the other for hyperParams
+        '''
+        paramDir = currDir + '/'
+        paramDict = {}
+        paramDict['W'] = np.load(paramDir + "W.npy")
+        paramDict['V'] = np.load(paramDir + "V.npy")
+        paramDict['T'] = np.load(paramDir + "T.npy")
+        paramDict['Z'] = np.load(paramDir + "Z.npy")
+        # allow_pickle is required because hyperParam.npy stores a dict.
+        hyperParamDict = np.load(paramDir + "hyperParam.npy",
+                                 allow_pickle=True).item()
+        return paramDict, hyperParamDict
+
+    def getModelSize(self):
+        '''
+        Function to get the aimed model size
+        '''
+        nnzZ, sizeZ, sparseZ = utils.countnnZ(self.bonsaiObj.Z, self.sZ)
+        nnzW, sizeW, sparseW = utils.countnnZ(self.bonsaiObj.W, self.sW)
+        nnzV, sizeV, sparseV = utils.countnnZ(self.bonsaiObj.V, self.sV)
+        nnzT, sizeT, sparseT = utils.countnnZ(self.bonsaiObj.T, self.sT)
+
+        totalnnZ = (nnzZ + nnzT + nnzV + nnzW)
+        totalSize = (sizeZ + sizeW + sizeV + sizeT)
+        hasSparse = (sparseW or sparseV or sparseT or sparseZ)
+        return totalnnZ, totalSize, hasSparse
+
+    def train(self, batchSize, totalEpochs, sess,
+              Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir):
+        '''
+        The Dense - IHT - Sparse Retrain routine for Bonsai training
+        '''
+        resultFile = open(dataDir + '/TFBonsaiResults.txt', 'a+')
+        numIters = Xtrain.shape[0] / batchSize
+
+        totalBatches = numIters * totalEpochs
+
+        bonsaiObjSigmaI = 1
+
+        counter = 0
+        if
self.bonsaiObj.numClasses > 2: + trimlevel = 15 + else: + trimlevel = 5 + ihtDone = 0 + if (self.bonsaiObj.isRegression is True): + maxTestAcc = 100000007 + else: + maxTestAcc = -10000 + if self.isDenseTraining is True: + ihtDone = 1 + bonsaiObjSigmaI = 1 + itersInPhase = 0 + + header = '*' * 20 + for i in range(totalEpochs): + print("\nEpoch Number: " + str(i), file=self.outFile) + + ''' + trainAcc -> For Regression, it is 'Mean Absolute Error'. + trainAcc -> For Classification, it is 'Accuracy'. + ''' + trainAcc = 0.0 + trainLoss = 0.0 + + numIters = int(numIters) + for j in range(numIters): + + if counter == 0: + msg = " Dense Training Phase Started " + print("\n%s%s%s\n" % + (header, msg, header), file=self.outFile) + + # Updating the indicator sigma + if ((counter == 0) or (counter == int(totalBatches / 3.0)) or + (counter == int(2 * totalBatches / 3.0))) and (self.isDenseTraining is False): + bonsaiObjSigmaI = 1 + itersInPhase = 0 + + elif (itersInPhase % 100 == 0): + indices = np.random.choice(Xtrain.shape[0], 100) + batchX = Xtrain[indices, :] + batchY = Ytrain[indices, :] + batchY = np.reshape( + batchY, [-1, self.bonsaiObj.numClasses]) + + _feed_dict = {self.X: batchX} + Xcapeval = self.X_.eval(feed_dict=_feed_dict) + Teval = self.bonsaiObj.T.eval() + + sum_tr = 0.0 + for k in range(0, self.bonsaiObj.internalNodes): + sum_tr += (np.sum(np.abs(np.dot(Teval[k], Xcapeval)))) + + if(self.bonsaiObj.internalNodes > 0): + sum_tr /= (100 * self.bonsaiObj.internalNodes) + sum_tr = 0.1 / sum_tr + else: + sum_tr = 0.1 + sum_tr = min( + 1000, sum_tr * (2**(float(itersInPhase) / + (float(totalBatches) / 30.0)))) + + bonsaiObjSigmaI = sum_tr + + itersInPhase += 1 + batchX = Xtrain[j * batchSize:(j + 1) * batchSize] + batchY = Ytrain[j * batchSize:(j + 1) * batchSize] + batchY = np.reshape( + batchY, [-1, self.bonsaiObj.numClasses]) + + if self.bonsaiObj.numClasses > 2: + if self.useMCHLoss is True: + _feed_dict = {self.X: batchX, self.Y: batchY, + self.batch_th: batchY.shape[0], + self.sigmaI: bonsaiObjSigmaI} + else: + _feed_dict = {self.X: batchX, self.Y: batchY, + self.sigmaI: bonsaiObjSigmaI} + else: + _feed_dict = {self.X: batchX, self.Y: batchY, + self.sigmaI: bonsaiObjSigmaI} + + # Mini-batch training + _, batchLoss, batchAcc = sess.run( + [self.trainStep, self.loss, self.accuracy], + feed_dict=_feed_dict) + + # Classification. + if (self.bonsaiObj.isRegression is False): + trainAcc += batchAcc + trainLoss += batchLoss + # Regression. 
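+                # (In the regression branch batchAcc holds the batch MAE,
+                # so trainAcc accumulates MAE across mini-batches.)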
+                else:
+                    trainAcc += np.mean(batchAcc)
+                    trainLoss += np.mean(batchLoss)
+
+                # Training routine involving IHT and sparse retraining
+                if (counter >= int(totalBatches / 3.0) and
+                        (counter < int(2 * totalBatches / 3.0)) and
+                        counter % trimlevel == 0 and
+                        self.isDenseTraining is False):
+                    self.runHardThrsd(sess)
+                    if ihtDone == 0:
+                        msg = " IHT Phase Started "
+                        print("\n%s%s%s\n" %
+                              (header, msg, header), file=self.outFile)
+                    ihtDone = 1
+                elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
+                       (counter < int(2 * totalBatches / 3.0)) and
+                       counter % trimlevel != 0 and
+                       self.isDenseTraining is False) or
+                        (counter >= int(2 * totalBatches / 3.0) and
+                         self.isDenseTraining is False)):
+                    self.runSparseTraining(sess)
+                    if counter == int(2 * totalBatches / 3.0):
+                        msg = " Sparse Retraining Phase Started "
+                        print("\n%s%s%s\n" %
+                              (header, msg, header), file=self.outFile)
+                counter += 1
+            try:
+                if (self.bonsaiObj.isRegression is True):
+                    print("\nRegression Train Loss: " + str(trainLoss / numIters) +
+                          "\nTraining MAE (Regression): " +
+                          str(trainAcc / numIters),
+                          file=self.outFile)
+                else:
+                    print("\nClassification Train Loss: " + str(trainLoss / numIters) +
+                          "\nTraining accuracy (Classification): " +
+                          str(trainAcc / numIters),
+                          file=self.outFile)
+            except Exception:
+                # Do not let a logging failure abort the epoch loop.
+                continue
+
+            oldSigmaI = bonsaiObjSigmaI
+            bonsaiObjSigmaI = 1e9
+
+            if self.bonsaiObj.numClasses > 2:
+                if self.useMCHLoss is True:
+                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                                  self.batch_th: Ytest.shape[0],
+                                  self.sigmaI: bonsaiObjSigmaI}
+                else:
+                    _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                                  self.sigmaI: bonsaiObjSigmaI}
+            else:
+                _feed_dict = {self.X: Xtest, self.Y: Ytest,
+                              self.sigmaI: bonsaiObjSigmaI}
+
+            # This helps in direct testing instead of extracting the model out
+
+            testAcc, testLoss, regTestLoss, pred = sess.run(
+                [self.accuracy, self.loss, self.regLoss, self.prediction], feed_dict=_feed_dict)
+
+            if ihtDone == 0:
+                if (self.bonsaiObj.isRegression is False):
+                    maxTestAcc = -10000
+                    maxTestAccEpoch = i
+                elif (self.bonsaiObj.isRegression is True):
+                    maxTestAcc = testAcc
+                    maxTestAccEpoch = i
+
+            else:
+                if (self.bonsaiObj.isRegression is False):
+                    if maxTestAcc <= testAcc:
+                        maxTestAccEpoch = i
+                        maxTestAcc = testAcc
+                        self.saveParams(currDir)
+                        self.saveParamsForSeeDot(currDir)
+                elif (self.bonsaiObj.isRegression is True):
+                    print("Minimum Test MAE so far: ", np.mean(maxTestAcc))
+                    if maxTestAcc >= testAcc:
+                        # For regression, we're more interested in the
+                        # minimum MAE.
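+                        # (Lower is better for MAE, hence the ">=" check:
+                        # the saved checkpoint tracks the minimum test MAE.)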
+                        maxTestAccEpoch = i
+                        maxTestAcc = testAcc
+                        self.saveParams(currDir)
+                        self.saveParamsForSeeDot(currDir)
+
+            if (self.bonsaiObj.isRegression is True):
+                print("Testing MAE %g" % np.mean(testAcc), file=self.outFile)
+            else:
+                print("Test accuracy %g" % np.mean(testAcc), file=self.outFile)
+
+            if (self.bonsaiObj.isRegression is True):
+                testAcc = np.mean(testAcc)
+
+            print("MarginLoss + RegLoss: " + str(testLoss - regTestLoss) +
+                  " + " + str(regTestLoss) + " = " + str(testLoss) + "\n",
+                  file=self.outFile)
+            self.outFile.flush()
+
+            bonsaiObjSigmaI = oldSigmaI
+
+        # sigmaI has to be set to infinity to ensure
+        # only a single path is used in inference
+        bonsaiObjSigmaI = 1e9
+        print("\nNon-Zero : " + str(self.getModelSize()[0]) + " Model Size: " +
+              str(float(self.getModelSize()[1]) / 1024.0) + " KB hasSparse: " +
+              str(self.getModelSize()[2]) + "\n", file=self.outFile)
+
+        if (self.bonsaiObj.isRegression is True):
+            maxTestAcc = np.mean(maxTestAcc)
+
+        if (self.bonsaiObj.isRegression is True):
+            print("For Regression, Minimum MAE at compressed" +
+                  " model size (including early stopping): " +
+                  str(maxTestAcc) + " at Epoch: " +
+                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
+                  " MAE: " + str(testAcc), file=self.outFile)
+
+            resultFile.write("MinTestMAE: " + str(maxTestAcc) +
+                             " at Epoch(totalEpochs): " +
+                             str(maxTestAccEpoch + 1) +
+                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
+                             str(float(self.getModelSize()[1]) / 1024.0) +
+                             " KB hasSparse: " + str(self.getModelSize()[2]) +
+                             " Param Directory: " +
+                             str(os.path.abspath(currDir)) + "\n")
+
+        elif (self.bonsaiObj.isRegression is False):
+            print("For Classification, Maximum Test accuracy at compressed" +
+                  " model size (including early stopping): " +
+                  str(maxTestAcc) + " at Epoch: " +
+                  str(maxTestAccEpoch + 1) + "\nFinal Test" +
+                  " Accuracy: " + str(testAcc), file=self.outFile)
+
+            resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
+                             " at Epoch(totalEpochs): " +
+                             str(maxTestAccEpoch + 1) +
+                             "(" + str(totalEpochs) + ")" + " ModelSize: " +
+                             str(float(self.getModelSize()[1]) / 1024.0) +
+                             " KB hasSparse: " + str(self.getModelSize()[2]) +
+                             " Param Directory: " +
+                             str(os.path.abspath(currDir)) + "\n")
+        print("The Model Directory: " + currDir + "\n")
+
+        resultFile.close()
+        self.outFile.flush()
+
+        if self.outFile is not sys.stdout:
+            self.outFile.close()
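Between the two trainers, a compact driver sketch for BonsaiTrainer above. Every shape, hyperparameter, and the random stand-in data here is illustrative only (not taken from this patch); real callers would load a dataset as in the examples directory:

```python
import numpy as np
import tensorflow as tf
from edgeml.graph.bonsai import Bonsai
from edgeml.trainer.bonsaiTrainer import BonsaiTrainer

tf.compat.v1.disable_eager_execution()

# Illustrative sizes: 784 features plus one bias column, 10 classes, depth-3 tree.
X = tf.compat.v1.placeholder("float32", [None, 785])
Y = tf.compat.v1.placeholder("float32", [None, 10])
bonsaiObj = Bonsai(numClasses=10, dataDimension=785,
                   projectionDimension=28, treeDepth=3, sigma=1.0)
trainer = BonsaiTrainer(bonsaiObj,
                        lW=1e-3, lT=1e-3, lV=1e-3, lZ=1e-5,
                        sW=0.3, sT=0.62, sV=0.26, sZ=0.2,
                        learningRate=0.01, X=X, Y=Y)

# Random stand-in data; a real caller loads a dataset instead.
Xtrain = np.random.rand(256, 785).astype(np.float32)
Ytrain = np.eye(10)[np.random.randint(0, 10, 256)].astype(np.float32)
Xtest, Ytest = Xtrain[:64], Ytrain[:64]

# InteractiveSession makes it the default session, which the trainer's
# internal .eval() calls rely on.
sess = tf.compat.v1.InteractiveSession()
sess.run(tf.compat.v1.global_variables_initializer())
trainer.train(batchSize=32, totalEpochs=3, sess=sess,
              Xtrain=Xtrain, Xtest=Xtest, Ytrain=Ytrain, Ytest=Ytest,
              dataDir='.', currDir='.')
```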
diff --git a/tf2.0/edgeml/trainer/fastTrainer.py b/tf2.0/edgeml/trainer/fastTrainer.py
new file mode 100644
index 000000000..bb1f51b10
--- /dev/null
+++ b/tf2.0/edgeml/trainer/fastTrainer.py
@@ -0,0 +1,527 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from __future__ import print_function
+import os
+import sys
+import tensorflow as tf
+import edgeml.utils as utils
+import numpy as np
+from tensorflow.python.framework import graph_util
+
+
+class FastTrainer:
+
+    def __init__(self, FastObj, X, Y, sW=1.0, sU=1.0, learningRate=0.01,
+                 outFile=None):
+        '''
+        FastObj - Can be either FastRNN or FastGRNN with proper initialisations
+        sW and sU are the sparsity factors for the Fast parameters
+        X is the Data Placeholder - Dims [_, timesteps, input_dims]
+        Y is the label placeholder for loss computation - Dims [_, num_classes]
+        learningRate is the initial learning rate
+        '''
+
+        self.FastObj = FastObj
+        self.history = []
+
+        self.sW = sW
+        self.sU = sU
+
+        self.Y = Y
+        self.X = X
+
+        self.numClasses = int(self.Y.shape[1])
+        self.timeSteps = int(self.X.shape[1])
+        self.inputDims = int(self.X.shape[2])
+
+        self.learningRate = learningRate
+
+        if outFile is not None:
+            self.outFile = open(outFile, 'w')
+        else:
+            self.outFile = sys.stdout
+
+        # Validate the sparsity hyperparameters up front (mirrors BonsaiTrainer).
+        self.assertInit()
+
+        self.lr = tf.compat.v1.placeholder("float", name="lr")
+
+        self.logits, self.finalHiddenState, self.predictions = self.computeGraph()
+
+        self.lossOp = self.lossGraph(self.logits, self.Y)
+        self.trainOp = self.trainGraph(self.lossOp, self.lr)
+
+        self.correctPredictions, self.accuracy = self.accuracyGraph(
+            self.predictions, self.Y)
+
+        self.numMatrices = self.FastObj.num_weight_matrices
+        self.totalMatrices = self.numMatrices[0] + self.numMatrices[1]
+
+        self.FastParams = self.FastObj.getVars()
+
+        if self.sW > 0.99 and self.sU > 0.99:
+            self.isDenseTraining = True
+        else:
+            self.isDenseTraining = False
+
+        self.hardThrsdGraph()
+        self.sparseTrainingGraph()
+
+    def RNN(self, x, timeSteps, FastObj):
+        '''
+        Unrolls the cell over timeSteps and returns the final hidden state
+        (the linear classifier is added on top in computeGraph)
+        '''
+        x = tf.unstack(x, timeSteps, 1)
+        outputs, states = tf.compat.v1.nn.static_rnn(FastObj, x, dtype=tf.float32)
+        return outputs[-1]
+
+    def computeGraph(self):
+        '''
+        Compute graph to unroll and predict on the FastObj
+        '''
+        finalHiddenState = self.RNN(self.X, self.timeSteps, self.FastObj)
+
+        logits = self.classifier(finalHiddenState)
+        predictions = tf.nn.softmax(logits, name='predictions')
+
+        return logits, finalHiddenState, predictions
+
+    def classifier(self, feats):
+        '''
+        Can be replaced by any classifier
+        TODO: Make this a separate class if needed
+        '''
+        self.FC = tf.Variable(tf.random.normal(
+            [self.FastObj.output_size, self.numClasses]), name='FC')
+        self.FCbias = tf.Variable(tf.random.normal(
+            [self.numClasses]), name='FCbias')
+
+        return tf.matmul(feats, self.FC) + self.FCbias
+
+    def lossGraph(self, logits, Y):
+        '''
+        Loss Graph for the given FastObj
+        '''
+        lossOp = utils.crossEntropyLoss(logits, Y)
+        return lossOp
+
+    def trainGraph(self, lossOp, lr):
+        '''
+        Train Graph for the loss generated by the FastObj
+        '''
+        optimizer = tf.compat.v1.train.AdamOptimizer(lr)
+        trainOp = optimizer.minimize(lossOp)
+        return trainOp
+
+    def accuracyGraph(self, predictions, Y):
+        '''
+        Accuracy Graph to evaluate accuracy when needed
+        '''
+        correctPredictions = tf.equal(
+            tf.argmax(input=predictions, axis=1), tf.argmax(input=Y, axis=1))
+        accuracy = tf.reduce_mean(input_tensor=tf.cast(correctPredictions, tf.float32))
+        return correctPredictions, accuracy
+
+    def assertInit(self):
+        err = "sparsity must be between 0 and 1"
+        assert self.sW >= 0 and self.sW <= 1, "W " + err
+        assert self.sU >= 0 and self.sU <= 1, "U " + err
+
+    def hardThrsdGraph(self):
+        '''
+        Set up for hard Thresholding Functionality
+        '''
+        self.paramPlaceholders = []
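+        # One placeholder per weight matrix: the W parts occupy indices
+        # [0, numMatrices[0]) and the U parts [numMatrices[0], totalMatrices).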
+        self.htOps = []
+        for i in range(0, self.numMatrices[0]):
+            self.paramPlaceholders.append(tf.compat.v1.placeholder(
+                tf.float32, name="Wth_" + str(i)))
+        for i in range(self.numMatrices[0], self.totalMatrices):
+            self.paramPlaceholders.append(tf.compat.v1.placeholder(
+                tf.float32, name="Uth_" + str(i)))
+
+        for i in range(0, self.numMatrices[0]):
+            self.htOps.append(
+                self.FastParams[i].assign(self.paramPlaceholders[i]))
+        for i in range(self.numMatrices[0], self.totalMatrices):
+            self.htOps.append(
+                self.FastParams[i].assign(self.paramPlaceholders[i]))
+
+        self.hardThresholdGroup = tf.group(*self.htOps)
+
+    def sparseTrainingGraph(self):
+        '''
+        Set up for Sparse Retraining Functionality
+        '''
+        self.stOps = []
+
+        for i in range(0, self.numMatrices[0]):
+            self.stOps.append(
+                self.FastParams[i].assign(self.paramPlaceholders[i]))
+        for i in range(self.numMatrices[0], self.totalMatrices):
+            self.stOps.append(
+                self.FastParams[i].assign(self.paramPlaceholders[i]))
+
+        self.sparseRetrainGroup = tf.group(*self.stOps)
+
+    def runHardThrsd(self, sess):
+        '''
+        Function to run the IHT routine on FastObj
+        '''
+        self.thrsdParams = []
+        for i in range(0, self.numMatrices[0]):
+            self.thrsdParams.append(
+                utils.hardThreshold(self.FastParams[i].eval(), self.sW))
+        for i in range(self.numMatrices[0], self.totalMatrices):
+            self.thrsdParams.append(
+                utils.hardThreshold(self.FastParams[i].eval(), self.sU))
+
+        fd_thrsd = {}
+        for i in range(0, self.totalMatrices):
+            fd_thrsd[self.paramPlaceholders[i]] = self.thrsdParams[i]
+        sess.run(self.hardThresholdGroup, feed_dict=fd_thrsd)
+
+    def runSparseTraining(self, sess):
+        '''
+        Function to run the Sparse Retraining routine on FastObj
+        '''
+        self.reTrainParams = []
+        for i in range(0, self.totalMatrices):
+            self.reTrainParams.append(
+                utils.copySupport(self.thrsdParams[i], self.FastParams[i].eval()))
+
+        fd_st = {}
+        for i in range(0, self.totalMatrices):
+            fd_st[self.paramPlaceholders[i]] = self.reTrainParams[i]
+        sess.run(self.sparseRetrainGroup, feed_dict=fd_st)
+
+    def getModelSize(self):
+        '''
+        Function to get the estimated model size
+        '''
+        totalnnZ = 0
+        totalSize = 0
+        hasSparse = False
+        for i in range(0, self.numMatrices[0]):
+            nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sW)
+            totalnnZ += nnz
+            totalSize += size
+            hasSparse = hasSparse or sparseFlag
+
+        for i in range(self.numMatrices[0], self.totalMatrices):
+            nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], self.sU)
+            totalnnZ += nnz
+            totalSize += size
+            hasSparse = hasSparse or sparseFlag
+        for i in range(self.totalMatrices, len(self.FastParams)):
+            nnz, size, sparseFlag = utils.countnnZ(self.FastParams[i], 1.0)
+            totalnnZ += nnz
+            totalSize += size
+            hasSparse = hasSparse or sparseFlag
+
+        # Replace this with classifier class call
+        nnz, size, sparseFlag = utils.countnnZ(self.FC, 1.0)
+        totalnnZ += nnz
+        totalSize += size
+        hasSparse = hasSparse or sparseFlag
+
+        nnz, size, sparseFlag = utils.countnnZ(self.FCbias, 1.0)
+        totalnnZ += nnz
+        totalSize += size
+        hasSparse = hasSparse or sparseFlag
+
+        return totalnnZ, totalSize, hasSparse
+
+    def saveParams(self, currDir):
+        '''
+        Function to save Parameter matrices
+        '''
+        # Naming convention: a single full-rank W is saved as W.npy; multiple
+        # full-rank matrices as W1..Wk; low-rank cells save the projection as
+        # W.npy followed by W1..W(k-1). The same convention applies to U.
+        if self.numMatrices[0] == 1:
+            np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval())
+        elif self.FastObj.wRank is None:
+            for i in range(self.numMatrices[0]):
+                np.save(os.path.join(currDir, "W" + str(i + 1) + ".npy"),
+                        self.FastParams[i].eval())
+        else:
+            np.save(os.path.join(currDir, "W.npy"), self.FastParams[0].eval())
+            for i in range(1, self.numMatrices[0]):
+                np.save(os.path.join(currDir, "W" + str(i) + ".npy"),
+                        self.FastParams[i].eval())
+
+        # U matrices are stored after the W matrices in FastParams (see
+        # hardThrsdGraph above), so every U index is offset by numMatrices[0].
+        offset = self.numMatrices[0]
+        if self.numMatrices[1] == 1:
+            np.save(os.path.join(currDir, "U.npy"),
+                    self.FastParams[offset].eval())
+        elif self.FastObj.uRank is None:
+            for i in range(self.numMatrices[1]):
+                np.save(os.path.join(currDir, "U" + str(i + 1) + ".npy"),
+                        self.FastParams[offset + i].eval())
+        else:
+            np.save(os.path.join(currDir, "U.npy"),
+                    self.FastParams[offset].eval())
+            for i in range(1, self.numMatrices[1]):
+                np.save(os.path.join(currDir, "U" + str(i) + ".npy"),
+                        self.FastParams[offset + i].eval())
+
+        if self.FastObj.cellType == "FastGRNN":
+            np.save(os.path.join(currDir, "Bg.npy"),
+                    self.FastParams[self.totalMatrices].eval())
+
np.save(os.path.join(currDir, "Bh.npy"), + self.FastParams[self.totalMatrices + 1].eval()) + np.save(os.path.join(currDir, "zeta.npy"), + self.FastParams[self.totalMatrices + 2].eval()) + np.save(os.path.join(currDir, "nu.npy"), + self.FastParams[self.totalMatrices + 3].eval()) + elif self.FastObj.cellType == "FastRNN": + np.save(os.path.join(currDir, "B.npy"), + self.FastParams[self.totalMatrices].eval()) + np.save(os.path.join(currDir, "alpha.npy"), self.FastParams[ + self.totalMatrices + 1].eval()) + np.save(os.path.join(currDir, "beta.npy"), + self.FastParams[self.totalMatrices + 2].eval()) + elif self.FastObj.cellType == "UGRNNLR": + np.save(os.path.join(currDir, "Bg.npy"), + self.FastParams[self.totalMatrices].eval()) + np.save(os.path.join(currDir, "Bh.npy"), + self.FastParams[self.totalMatrices + 1].eval()) + elif self.FastObj.cellType == "GRULR": + np.save(os.path.join(currDir, "Br.npy"), + self.FastParams[self.totalMatrices].eval()) + np.save(os.path.join(currDir, "Bg.npy"), + self.FastParams[self.totalMatrices + 1].eval()) + np.save(os.path.join(currDir, "Bh.npy"), + self.FastParams[self.totalMatrices + 2].eval()) + elif self.FastObj.cellType == "LSTMLR": + np.save(os.path.join(currDir, "Bf.npy"), + self.FastParams[self.totalMatrices].eval()) + np.save(os.path.join(currDir, "Bi.npy"), + self.FastParams[self.totalMatrices + 1].eval()) + np.save(os.path.join(currDir, "Bc.npy"), + self.FastParams[self.totalMatrices + 2].eval()) + np.save(os.path.join(currDir, "Bo.npy"), + self.FastParams[self.totalMatrices + 3].eval()) + + np.save(os.path.join(currDir, "FC.npy"), self.FC.eval()) + np.save(os.path.join(currDir, "FCbias.npy"), self.FCbias.eval()) + + def train(self, batchSize, totalEpochs, sess, + Xtrain, Xtest, Ytrain, Ytest, + decayStep, decayRate, dataDir, currDir): + ''' + The Dense - IHT - Sparse Retrain Routine for FastCell Training + ''' + fileName = str(self.FastObj.cellType) + 'Results.txt' + resultFile = open(os.path.join(dataDir, fileName), 'a+') + numIters = int(np.ceil(float(Xtrain.shape[0]) / float(batchSize))) + totalBatches = numIters * totalEpochs + + counter = 0 + trimlevel = 15 + ihtDone = 0 + maxTestAcc = -10000 + if self.isDenseTraining is True: + ihtDone = 1 + maxTestAcc = -10000 + header = '*' * 20 + + Xtest = Xtest.reshape((-1, self.timeSteps, self.inputDims)) + Xtrain = Xtrain.reshape((-1, self.timeSteps, self.inputDims)) + + self.history = [] + + for i in range(0, totalEpochs): + print("\nEpoch Number: " + str(i), file=self.outFile) + + if i % decayStep == 0 and i != 0: + self.learningRate = self.learningRate * decayRate + + shuffled = list(range(Xtrain.shape[0])) + np.random.shuffle(shuffled) + trainAcc = 0.0 + trainLoss = 0.0 + + numIters = int(numIters) + for j in range(0, numIters): + + if counter == 0: + msg = " Dense Training Phase Started " + print("\n%s%s%s\n" % + (header, msg, header), file=self.outFile) + + k = shuffled[j * batchSize:(j + 1) * batchSize] + batchX = Xtrain[k] + batchY = Ytrain[k] + + # Mini-batch training + _, batchLoss, batchAcc = sess.run([self.trainOp, self.lossOp, self.accuracy], feed_dict={ + self.X: batchX, self.Y: batchY, self.lr: self.learningRate}) + + trainAcc += batchAcc + trainLoss += batchLoss + + # Training routine involving IHT and sparse retraining + if (counter >= int(totalBatches / 3.0) and + (counter < int(2 * totalBatches / 3.0)) and + counter % trimlevel == 0 and + self.isDenseTraining is False): + self.runHardThrsd(sess) + if ihtDone == 0: + msg = " IHT Phase Started " + print("\n%s%s%s\n" % + (header, msg, 
header), file=self.outFile)
+                        ihtDone = 1
+                elif ((ihtDone == 1 and counter >= int(totalBatches / 3.0) and
+                       (counter < int(2 * totalBatches / 3.0)) and
+                       counter % trimlevel != 0 and
+                       self.isDenseTraining is False) or
+                      (counter >= int(2 * totalBatches / 3.0) and
+                       self.isDenseTraining is False)):
+                    self.runSparseTraining(sess)
+                    if counter == int(2 * totalBatches / 3.0):
+                        msg = " Sparse Retraining Phase Started "
+                        print("\n%s%s%s\n" %
+                              (header, msg, header), file=self.outFile)
+                counter += 1
+
+            trainLoss /= numIters
+            trainAcc /= numIters
+            print("Train Loss: " + str(trainLoss) +
+                  " Train Accuracy: " + str(trainAcc),
+                  file=self.outFile)
+
+            testAcc, testLoss = sess.run([self.accuracy, self.lossOp], feed_dict={
+                self.X: Xtest, self.Y: Ytest})
+
+            self.history += [
+                {
+                    "epoch": i,
+                    "trainAcc": trainAcc,
+                    "trainLoss": trainLoss,
+                    "testAcc": testAcc,
+                    "testLoss": testLoss
+                }
+            ]
+
+            if ihtDone == 0:
+                maxTestAcc = -10000
+                maxTestAccEpoch = i
+            else:
+                if maxTestAcc <= testAcc:
+                    maxTestAccEpoch = i
+                    maxTestAcc = testAcc
+                    self.saveParams(currDir)
+
+            print("Test Loss: " + str(testLoss) +
+                  " Test Accuracy: " + str(testAcc), file=self.outFile)
+            self.outFile.flush()
+
+        print("\nMaximum Test accuracy at compressed" +
+              " model size (including early stopping): " +
+              str(maxTestAcc) + " at Epoch: " +
+              str(maxTestAccEpoch + 1) + "\nFinal Test" +
+              " Accuracy: " + str(testAcc), file=self.outFile)
+        print("\n\nNon-Zeros: " + str(self.getModelSize()[0]) +
+              " Model Size: " + str(float(self.getModelSize()[1]) / 1024.0) +
+              " KB hasSparse: " + str(self.getModelSize()[2]) + "\n",
+              file=self.outFile)
+
+        resultFile.write("MaxTestAcc: " + str(maxTestAcc) +
+                         " at Epoch(totalEpochs): " +
+                         str(maxTestAccEpoch + 1) +
+                         "(" + str(totalEpochs) + ")" + " ModelSize: " +
+                         str(float(self.getModelSize()[1]) / 1024.0) +
+                         " KB hasSparse: " + str(self.getModelSize()[2]) +
+                         " Param Directory: " +
+                         str(os.path.abspath(currDir)) + "\n")
+
+        print("The Model Directory: " + currDir + "\n")
+
+        # directory for exporting the tensorflow model
+        model_dir = os.path.join(currDir, "model")
+        os.makedirs(model_dir, exist_ok=True)
+
+        resultFile.close()
+        self.outFile.flush()
+        if self.outFile is not sys.stdout:
+            self.outFile.close()
+
+    def getAccuracyLog(self):
+        return self.history
diff --git a/tf2.0/edgeml/trainer/protoNNTrainer.py b/tf2.0/edgeml/trainer/protoNNTrainer.py
new file mode 100644
index 000000000..27de4d7de
--- /dev/null
+++ b/tf2.0/edgeml/trainer/protoNNTrainer.py
@@ -0,0 +1,219 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from __future__ import print_function
+import tensorflow as tf
+import numpy as np
+import sys
+import edgeml.utils as utils
+
+
+class ProtoNNTrainer:
+    def __init__(self, protoNNObj, regW, regB, regZ,
+                 sparcityW, sparcityB, sparcityZ,
+                 learningRate, X, Y, lossType='l2'):
+        '''
+        A wrapper for the various techniques used for training ProtoNN. This
+        subsumes both the responsibility of loss graph construction and
+        performing training. The original training routine that is part of the
+        C++ implementation of EdgeML used iterative hard thresholding (IHT),
+        gamma estimation through the median heuristic and other tricks for
+        training ProtoNN. This module implements the same in TensorFlow
+        and Python.
+
+        protoNNObj: An instance of the ProtoNN class defining the forward
+        computation graph. The loss functions and training routines will be
+        attached to this instance.
+ regW, regB, regZ: Regularization constants for W, B, and + Z matrices of protoNN. + sparcityW, sparcityB, sparcityZ: Sparsity constraints + for W, B and Z matrices. A value between 0 (exclusive) and 1 + (inclusive) is expected. A value of 1 indicates dense training. + learningRate: Initial learning rate for ADAM optimizer. + X, Y : Placeholders for data and labels. + X [-1, featureDimension] + Y [-1, num Labels] + lossType: ['l2', 'xentropy'] + ''' + self.protoNNObj = protoNNObj + self.__regW = regW + self.__regB = regB + self.__regZ = regZ + self.__sW = sparcityW + self.__sB = sparcityB + self.__sZ = sparcityZ + self.__lR = learningRate + self.X = X + self.Y = Y + self.sparseTraining = True + if (sparcityW == 1.0) and (sparcityB == 1.0) and (sparcityZ == 1.0): + self.sparseTraining = False + print("Sparse training disabled.", file=sys.stderr) + # Define placeholders for sparse training + self.W_th = None + self.B_th = None + self.Z_th = None + self.__lossType = lossType + self.__validInit = False + self.__validInit = self.__validateInit() + self.__protoNNOut = protoNNObj(X, Y) + self.loss = self.__lossGraph() + self.trainStep = self.__trainGraph() + self.__hthOp = self.__getHardThresholdOp() + self.accuracy = protoNNObj.getAccuracyOp() + + def __validateInit(self): + self.__validInit = False + msg = "Sparsity value should be between" + msg += " 0 and 1 (both inclusive)." + assert self.__sW >= 0. and self.__sW <= 1., 'W:' + msg + assert self.__sB >= 0. and self.__sB <= 1., 'B:' + msg + assert self.__sZ >= 0. and self.__sZ <= 1., 'Z:' + msg + d, dcap, m, L, _ = self.protoNNObj.getHyperParams() + msg = 'Y should be of dimension [-1, num labels/classes]' + msg += ' specified as part of ProtoNN object.' + assert (len(self.Y.shape)) == 2, msg + assert (self.Y.shape[1] == L), msg + msg = 'X should be of dimension [-1, featureDimension]' + msg += ' specified as part of ProtoNN object.' 
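+        # d is the input feature dimension declared by the ProtoNN object;
+        # the X placeholder must agree with it, just as Y must match L.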
+        assert (len(self.X.shape) == 2), msg
+        assert (self.X.shape[1] == d), msg
+        self.__validInit = True
+        msg = 'Values can be \'l2\' or \'xentropy\''
+        if self.__lossType not in ['l2', 'xentropy']:
+            raise ValueError(msg)
+        return True
+
+    def __lossGraph(self):
+        pnnOut = self.__protoNNOut
+        l1, l2, l3 = self.__regW, self.__regB, self.__regZ
+        W, B, Z, _ = self.protoNNObj.getModelMatrices()
+        if self.__lossType == 'l2':
+            with tf.compat.v1.name_scope('protonn-l2-loss'):
+                loss_0 = tf.nn.l2_loss(self.Y - pnnOut)
+                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
+                reg += l3 * tf.nn.l2_loss(Z)
+                loss = loss_0 + reg
+        elif self.__lossType == 'xentropy':
+            with tf.compat.v1.name_scope('protonn-xentropy-loss'):
+                loss_0 = tf.nn.softmax_cross_entropy_with_logits(logits=pnnOut,
+                                                                 labels=tf.stop_gradient(self.Y))
+                loss_0 = tf.reduce_mean(input_tensor=loss_0)
+                reg = l1 * tf.nn.l2_loss(W) + l2 * tf.nn.l2_loss(B)
+                reg += l3 * tf.nn.l2_loss(Z)
+                loss = loss_0 + reg
+        return loss
+
+    def __trainGraph(self):
+        with tf.compat.v1.name_scope('protonn-gradient-adam'):
+            trainStep = tf.compat.v1.train.AdamOptimizer(self.__lR)
+            trainStep = trainStep.minimize(self.loss)
+        return trainStep
+
+    def __getHardThresholdOp(self):
+        W, B, Z, _ = self.protoNNObj.getModelMatrices()
+        self.W_th = tf.compat.v1.placeholder(tf.float32, name='W_th')
+        self.B_th = tf.compat.v1.placeholder(tf.float32, name='B_th')
+        self.Z_th = tf.compat.v1.placeholder(tf.float32, name='Z_th')
+        with tf.compat.v1.name_scope('hard-threshold-assignments'):
+            # hard_thrsd_W = W.assign(self.W_th)
+            # hard_thrsd_B = B.assign(self.B_th)
+            # hard_thrsd_Z = Z.assign(self.Z_th)
+            # Code changes for tf 1.11
+            hard_thrsd_W = tf.compat.v1.assign(W, self.W_th)
+            hard_thrsd_B = tf.compat.v1.assign(B, self.B_th)
+            hard_thrsd_Z = tf.compat.v1.assign(Z, self.Z_th)
+            hard_thrsd_op = tf.group(hard_thrsd_W, hard_thrsd_B, hard_thrsd_Z)
+        return hard_thrsd_op
+
+    def train(self, batchSize, totalEpochs, sess,
+              x_train, x_val, y_train, y_val, noInit=False,
+              redirFile=None, printStep=10, valStep=3):
+        '''
+        Performs dense training of ProtoNN followed by iterative hard
+        thresholding to enforce sparsity constraints.
+
+        batchSize: Batch size per update
+        totalEpochs: The number of epochs to run training for. One epoch is
+            defined as one pass over the entire training data.
+        sess: The Tensorflow session to use for running various graph
+            operators.
+        x_train, x_val, y_train, y_val: The numpy arrays containing train and
+            validation data. x data is assumed to be of shape [-1,
+            featureDimension] while y should have shape [-1, numberLabels].
+        noInit: By default, all the tensors of the computation graph are
+            initialized at the start of the training session. Set noInit=True
+            to disable this behaviour.
+        printStep: Number of batches between echoing of loss and train accuracy.
+        valStep: Number of epochs between evaluations on the validation set.
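+
+        A minimal usage sketch (shapes and names are hypothetical; assumes
+        `trainer` is a constructed ProtoNNTrainer and `sess` is an active
+        TF session):
+            trainer.train(batchSize=32, totalEpochs=200, sess=sess,
+                          x_train=x_train, x_val=x_val,
+                          y_train=y_train, y_val=y_val)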
+        '''
+        d, d_cap, m, L, gamma = self.protoNNObj.getHyperParams()
+        assert batchSize >= 1, 'Batch size should be positive integer'
+        assert totalEpochs >= 1, 'Total epochs should be positive integer'
+        assert x_train.ndim == 2, 'Expected training data to be of rank 2'
+        assert x_train.shape[1] == d, 'Expected x_train to be [-1, %d]' % d
+        assert x_val.ndim == 2, 'Expected validation data to be of rank 2'
+        assert x_val.shape[1] == d, 'Expected x_val to be [-1, %d]' % d
+        assert y_train.ndim == 2, 'Expected training labels to be of rank 2'
+        assert y_train.shape[1] == L, 'Expected y_train to be [-1, %d]' % L
+        assert y_val.ndim == 2, 'Expected validation labels to be of rank 2'
+        assert y_val.shape[1] == L, 'Expected y_val to be [-1, %d]' % L
+
+        # Numpy will throw asserts for arrays
+        if sess is None:
+            raise ValueError('sess must be valid Tensorflow session.')
+
+        trainNumBatches = int(np.ceil(len(x_train) / batchSize))
+        valNumBatches = int(np.ceil(len(x_val) / batchSize))
+        x_train_batches = np.array_split(x_train, trainNumBatches)
+        y_train_batches = np.array_split(y_train, trainNumBatches)
+        x_val_batches = np.array_split(x_val, valNumBatches)
+        y_val_batches = np.array_split(y_val, valNumBatches)
+        if not noInit:
+            sess.run(tf.compat.v1.global_variables_initializer())
+        X, Y = self.X, self.Y
+        W, B, Z, _ = self.protoNNObj.getModelMatrices()
+        for epoch in range(totalEpochs):
+            for i in range(len(x_train_batches)):
+                batch_x = x_train_batches[i]
+                batch_y = y_train_batches[i]
+                feed_dict = {
+                    X: batch_x,
+                    Y: batch_y
+                }
+                sess.run(self.trainStep, feed_dict=feed_dict)
+                if i % printStep == 0:
+                    loss, acc = sess.run([self.loss, self.accuracy],
+                                         feed_dict=feed_dict)
+                    msg = "Epoch: %3d Batch: %3d" % (epoch, i)
+                    msg += " Loss: %3.5f Accuracy: %2.5f" % (loss, acc)
+                    print(msg, file=redirFile)
+
+            # Perform Hard thresholding
+            if self.sparseTraining:
+                W_, B_, Z_ = sess.run([W, B, Z])
+                fd_thrsd = {
+                    self.W_th: utils.hardThreshold(W_, self.__sW),
+                    self.B_th: utils.hardThreshold(B_, self.__sB),
+                    self.Z_th: utils.hardThreshold(Z_, self.__sZ)
+                }
+                sess.run(self.__hthOp, feed_dict=fd_thrsd)
+
+            if (epoch + 1) % valStep == 0:
+                acc = 0.0
+                loss = 0.0
+                for j in range(len(x_val_batches)):
+                    batch_x = x_val_batches[j]
+                    batch_y = y_val_batches[j]
+                    feed_dict = {
+                        X: batch_x,
+                        Y: batch_y
+                    }
+                    acc_, loss_ = sess.run([self.accuracy, self.loss],
+                                           feed_dict=feed_dict)
+                    acc += acc_
+                    loss += loss_
+                acc /= len(y_val_batches)
+                loss /= len(y_val_batches)
+                print("Validation Loss: %2.5f Accuracy: %2.5f" % (loss, acc),
+                      file=redirFile)
+
diff --git a/tf2.0/edgeml/utils.py b/tf2.0/edgeml/utils.py
new file mode 100644
index 000000000..b3ff5adb4
--- /dev/null
+++ b/tf2.0/edgeml/utils.py
@@ -0,0 +1,339 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+from __future__ import print_function
+import tensorflow as tf
+import numpy as np
+import scipy.cluster
+import scipy.spatial
+import os
+
+
+def medianHeuristic(data, projectionDimension, numPrototypes, W_init=None):
+    '''
+    This method can be used to estimate gamma for ProtoNN. An approximation to
+    the median heuristic is used here.
+    1. First, the data is projected down to projectionDimension by W_init. If
+    W_init is not provided, it is initialized from a random normal(0, 1). Hence
+    data normalization is essential.
+    2. Prototypes are computed by running a k-means clustering on the projected
+    data.
+    3. The median distance is then estimated by calculating the median distance
+    between prototypes and projected data points.
+
+    data needs to be [-1, numFeats]
+    If using this method to initialize gamma, please use the W and B as well.
+
+    TODO: Return estimate of Z (prototype labels) based on cluster centroids
+    and labels
+
+    TODO: Clustering fails due to singularity error if projecting upwards
+
+    W [d, d_cap]
+    B [d_cap, m]
+    returns gamma, W, B
+    '''
+    assert data.ndim == 2
+    X = data
+    featDim = data.shape[1]
+    if projectionDimension > featDim:
+        print("Warning: Projection dimension > feature dimension. Gamma")
+        print("\t estimation due to median heuristic could fail.")
+        print("\tTo retain the projection dataDimension, provide")
+        print("\ta value for gamma.")
+
+    if W_init is None:
+        W_init = np.random.normal(size=[featDim, projectionDimension])
+    W = W_init
+    XW = np.matmul(X, W)
+    assert XW.shape[1] == projectionDimension
+    assert XW.shape[0] == len(X)
+    # Requires the [N x d_cap] data matrix of N observations of d_cap
+    # dimensions and the number of centroids m. Returns [m x d_cap]
+    # centroids and the per-point cluster assignments.
+    B, centers = scipy.cluster.vq.kmeans2(XW, numPrototypes)
+    # Requires two matrices of shape [num observations x observation
+    # dimension]. Distances[i,j] is the distance between XW[i] and B[j].
+    distances = scipy.spatial.distance.cdist(XW, B, metric='euclidean')
+    distances = np.reshape(distances, [-1])
+    gamma = np.median(distances)
+    gamma = 1 / (2.5 * gamma)
+    return gamma.astype('float32'), W.astype('float32'), B.T.astype('float32')
+
+
+def multiClassHingeLoss(logits, label, batch_th):
+    '''
+    MultiClassHingeLoss to match the C++ version - no TF internal version
+    '''
+    flatLogits = tf.reshape(logits, [-1, ])
+    label_ = tf.argmax(input=label, axis=1)
+
+    correctId = tf.range(0, batch_th) * label.shape[1] + label_
+    correctLogit = tf.gather(flatLogits, correctId)
+
+    maxLabel = tf.argmax(input=logits, axis=1)
+    top2, _ = tf.nn.top_k(logits, k=2, sorted=True)
+
+    wrongMaxLogit = tf.where(
+        tf.equal(maxLabel, label_), top2[:, 1], top2[:, 0])
+
+    return tf.reduce_mean(input_tensor=tf.nn.relu(1. + wrongMaxLogit - correctLogit))
+
+
+def crossEntropyLoss(logits, label):
+    '''
+    Cross Entropy loss for the MultiClass case in joint training for
+    faster convergence
+    '''
+    return tf.reduce_mean(
+        input_tensor=tf.nn.softmax_cross_entropy_with_logits(logits=logits,
+                                                             labels=tf.stop_gradient(label)))
+
+
+def mean_absolute_error(logits, label):
+    '''
+    Function to compute the mean absolute error.
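+    (i.e., the elementwise mean of |logits - label|)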
+    '''
+    return tf.reduce_mean(input_tensor=tf.abs(tf.subtract(logits, label)))
+
+
+def hardThreshold(A, s):
+    '''
+    Hard thresholding function on numpy tensor A with sparsity s:
+    retains the fraction s of entries that are largest in magnitude
+    and zeroes out the rest.
+    '''
+    A_ = np.copy(A)
+    A_ = A_.ravel()
+    if len(A_) > 0:
+        th = np.percentile(np.abs(A_), (1 - s) * 100.0, interpolation='higher')
+        A_[np.abs(A_) < th] = 0.0
+    A_ = A_.reshape(A.shape)
+    return A_
+
+
+def copySupport(src, dest):
+    '''
+    Copy the support (non-zero positions) of the src tensor to the dest tensor
+    '''
+    support = np.nonzero(src)
+    dest_ = dest
+    dest = np.zeros(dest_.shape)
+    dest[support] = dest_[support]
+    return dest
+
+
+def countnnZ(A, s, bytesPerVar=4):
+    '''
+    Returns the number of non-zeros and the representative size of the tensor.
+    Assumes dense storage (bytesPerVar per value) for s >= 0.5;
+    else assumes sparse storage (2 * bytesPerVar per value).
+    '''
+    params = 1
+    hasSparse = False
+    for i in range(0, len(A.shape)):
+        params *= int(A.shape[i])
+    if s < 0.5:
+        nnZ = np.ceil(params * s)
+        hasSparse = True
+        return nnZ, nnZ * 2 * bytesPerVar, hasSparse
+    else:
+        nnZ = params
+        return nnZ, nnZ * bytesPerVar, hasSparse
+
+
+def getConfusionMatrix(predicted, target, numClasses):
+    '''
+    Returns a confusion matrix for a multiclass classification
+    problem. `predicted` is a 1-D array of integers representing
+    the predicted classes and `target` is the target classes.
+
+    confusion[i][j]: Number of elements of class j
+        predicted as class i
+    Labels are assumed to be in range(0, numClasses)
+    Use `printFormattedConfusionMatrix` to echo the confusion matrix
+    in a user friendly form.
+    '''
+    assert(predicted.ndim == 1)
+    assert(target.ndim == 1)
+    arr = np.zeros([numClasses, numClasses])
+
+    for i in range(len(predicted)):
+        arr[predicted[i]][target[i]] += 1
+    return arr
+
+
+def printFormattedConfusionMatrix(matrix):
+    '''
+    Given a 2D confusion matrix, prints it in a human readable way.
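+    Rows are predicted classes and columns are true classes
+    (see `getConfusionMatrix`).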
+ The confusion matrix is expected to be a 2D numpy array with + square dimensions + ''' + assert(matrix.ndim == 2) + assert(matrix.shape[0] == matrix.shape[1]) + RECALL = 'Recall' + PRECISION = 'PRECISION' + print("|%s|" % ('True->'), end='') + for i in range(matrix.shape[0]): + print("%7d|" % i, end='') + print("%s|" % 'Precision') + + print("|%s|" % ('-' * len(RECALL)), end='') + for i in range(matrix.shape[0]): + print("%s|" % ('-' * 7), end='') + print("%s|" % ('-' * len(PRECISION))) + + precisionlist = np.sum(matrix, axis=1) + recalllist = np.sum(matrix, axis=0) + precisionlist = [matrix[i][i] / x if x != + 0 else -1 for i, x in enumerate(precisionlist)] + recalllist = [matrix[i][i] / x if x != + 0 else -1 for i, x in enumerate(recalllist)] + for i in range(matrix.shape[0]): + # len recall = 6 + print("|%6d|" % (i), end='') + for j in range(matrix.shape[0]): + print("%7d|" % (matrix[i][j]), end='') + print("%s" % (" " * (len(PRECISION) - 7)), end='') + if precisionlist[i] != -1: + print("%1.5f|" % precisionlist[i]) + else: + print("%7s|" % "nan") + + print("|%s|" % ('-' * len(RECALL)), end='') + for i in range(matrix.shape[0]): + print("%s|" % ('-' * 7), end='') + print("%s|" % ('-' * len(PRECISION))) + print("|%s|" % ('Recall'), end='') + + for i in range(matrix.shape[0]): + if recalllist[i] != -1: + print("%1.5f|" % (recalllist[i]), end='') + else: + print("%7s|" % "nan", end='') + + print('%s|' % (' ' * len(PRECISION))) + + +def getPrecisionRecall(cmatrix, label=1): + trueP = cmatrix[label][label] + denom = np.sum(cmatrix, axis=0)[label] + if denom == 0: + denom = 1 + recall = trueP / denom + denom = np.sum(cmatrix, axis=1)[label] + if denom == 0: + denom = 1 + precision = trueP / denom + return precision, recall + + +def getMacroPrecisionRecall(cmatrix): + # TP + FP + precisionlist = np.sum(cmatrix, axis=1) + # TP + FN + recalllist = np.sum(cmatrix, axis=0) + precisionlist__ = [cmatrix[i][i] / x if x != + 0 else 0 for i, x in enumerate(precisionlist)] + recalllist__ = [cmatrix[i][i] / x if x != + 0 else 0 for i, x in enumerate(recalllist)] + precision = np.sum(precisionlist__) + precision /= len(precisionlist__) + recall = np.sum(recalllist__) + recall /= len(recalllist__) + return precision, recall + + +def getMicroPrecisionRecall(cmatrix): + # TP + FP + precisionlist = np.sum(cmatrix, axis=1) + # TP + FN + recalllist = np.sum(cmatrix, axis=0) + num = 0.0 + for i in range(len(cmatrix)): + num += cmatrix[i][i] + + precision = num / np.sum(precisionlist) + recall = num / np.sum(recalllist) + return precision, recall + + +def getMacroMicroFScore(cmatrix): + ''' + Returns macro and micro f-scores. 
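+    Macro-F averages the per-class F1 scores; micro-F computes F1 from the
+    pooled counts.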
+    Refer: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.104.8244&rep=rep1&type=pdf
+    '''
+    precisionlist = np.sum(cmatrix, axis=1)
+    recalllist = np.sum(cmatrix, axis=0)
+    precisionlist__ = [cmatrix[i][i] / x if x !=
+                       0 else 0 for i, x in enumerate(precisionlist)]
+    recalllist__ = [cmatrix[i][i] / x if x !=
+                    0 else 0 for i, x in enumerate(recalllist)]
+    macro = 0.0
+    for i in range(len(precisionlist)):
+        denom = precisionlist__[i] + recalllist__[i]
+        numer = precisionlist__[i] * recalllist__[i] * 2
+        if denom == 0:
+            denom = 1
+        macro += numer / denom
+    macro /= len(precisionlist)
+
+    num = 0.0
+    for i in range(len(precisionlist)):
+        num += cmatrix[i][i]
+
+    denom1 = np.sum(precisionlist)
+    denom2 = np.sum(recalllist)
+    pi = num / denom1
+    rho = num / denom2
+    denom = pi + rho
+    if denom == 0:
+        denom = 1
+    micro = 2 * pi * rho / denom
+    return macro, micro
+
+
+def restructreMatrixBonsaiSeeDot(A, nClasses, nNodes):
+    '''
+    Restructures a matrix from [nNodes*nClasses, Proj] to
+    [nClasses*nNodes, Proj] for SeeDot
+    '''
+    tempMatrix = np.zeros(A.shape)
+    rowIndex = 0
+
+    for i in range(0, nClasses):
+        for j in range(0, nNodes):
+            tempMatrix[rowIndex] = A[j * nClasses + i]
+            rowIndex += 1
+
+    return tempMatrix
+
+
+class GraphManager:
+    '''
+    Manages saving and restoring graphs. Designed to be used with EMI-RNN
+    though it is general enough to be useful otherwise as well.
+    '''
+
+    def __init__(self):
+        pass
+
+    def checkpointModel(self, saver, sess, modelPrefix,
+                        globalStep=1000, redirFile=None):
+        saver.save(sess, modelPrefix, global_step=globalStep)
+        print('Model saved to %s, global_step %d' % (modelPrefix, globalStep),
+              file=redirFile)
+
+    def loadCheckpoint(self, sess, modelPrefix, globalStep,
+                       redirFile=None):
+        metaname = modelPrefix + '-%d.meta' % globalStep
+        basename = os.path.basename(metaname)
+        fileList = os.listdir(os.path.dirname(modelPrefix))
+        fileList = [x for x in fileList if x.startswith(basename)]
+        assert len(fileList) > 0, 'Checkpoint file not found'
+        msg = 'Too many or too few checkpoint files for globalStep: %d' % globalStep
+        # 'is' checks identity, not value; '==' is the correct comparison here
+        assert len(fileList) == 1, msg
+        chkpt = basename + '/' + fileList[0]
+        saver = tf.compat.v1.train.import_meta_graph(metaname)
+        metaname = metaname[:-5]
+        saver.restore(sess, metaname)
+        graph = tf.compat.v1.get_default_graph()
+        return graph
diff --git a/tf2.0/examples/Bonsai/README.md b/tf2.0/examples/Bonsai/README.md
new file mode 100644
index 000000000..91cb00213
--- /dev/null
+++ b/tf2.0/examples/Bonsai/README.md
@@ -0,0 +1,67 @@
+# EdgeML Bonsai on a sample public dataset
+
+This directory includes an example notebook and a general execution script
+for Bonsai, developed as part of EdgeML. We also include a sample cleanup and
+use-case on the USPS10 public dataset.
+
+`edgeml.graph.bonsai` implements the Bonsai prediction graph in tensorflow.
+The three-phase training routine for Bonsai is decoupled from the forward graph
+to facilitate a plug and play behaviour wherein Bonsai can be combined with or
+used as a final layer classifier for other architectures (RNNs, CNNs).
+
+Note that `bonsai_example.py` assumes that data is in a specific format. It is
+assumed that train and test data is contained in two files, `train.npy` and
+`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
+numberOfFeatures + 1]`. The first column of each matrix is assumed to contain
+label information. For an N-Class problem, we assume the labels are integers
+from 0 through N-1.
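+
+To make this format concrete, here is a small sketch (hypothetical file
+locations and random data, purely illustrative) that writes a 3-class dataset
+in the expected layout:
+
+```python
+import numpy as np
+
+# 100 examples, 16 features; the integer label (0, 1 or 2) goes in column 0
+labels = np.random.randint(0, 3, size=(100, 1))
+features = np.random.randn(100, 16)
+np.save('train.npy', np.hstack([labels, features]))  # test.npy is analogous
+```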
+`bonsai_example.py` also supports univariate regression,
+accessible through the help options of the script. Multivariate regression
+requires restructuring of the input data format, and doing so would also help
+in extending Bonsai to multi-label classification. Lastly,
+the training data, `train.npy`, is assumed to be well shuffled
+as the training routine doesn't shuffle internally.
+
+**Tested With:** Tensorflow >1.6 with Python 2 and Python 3
+
+## Download and clean up sample dataset
+
+We validate the code using the USPS dataset.
+The download and cleanup of the dataset to match the above-mentioned format is
+done by the scripts [fetch_usps.py](fetch_usps.py) and
+[process_usps.py](process_usps.py)
+
+```
+python fetch_usps.py
+python process_usps.py
+```
+
+## Sample command for Bonsai on USPS10
+The following sample run on usps10 should validate your library:
+
+```bash
+python bonsai_example.py -dir usps10/ -d 3 -p 28 -rW 0.001 -rZ 0.0001 -rV 0.001 -rT 0.001 -sZ 0.2 -sW 0.3 -sV 0.3 -sT 0.62 -e 100 -s 1
+```
+This command should give you a final output screen which reads roughly similar
+to the following (the numbers might not be exact due to version mismatches):
+```
+Maximum Test accuracy at compressed model size(including early stopping): 0.94369704 at Epoch: 66
+Final Test Accuracy: 0.93024415
+
+Non-Zeros: 4156.0 Model Size: 31.703125 KB hasSparse: True
+```
+
+The usps10 directory will now have a consolidated results file called
+`TFBonsaiResults.txt` and a directory `TFBonsaiResults` containing the
+corresponding models from each run of the code on the usps10 dataset.
+
+## Byte Quantization (Q) for model compression
+If you wish to quantize the generated model to use byte quantized integers use `quantizeBonsaiModels.py`. Usage Instructions:
+
+```
+python quantizeBonsaiModels.py -h
+```
+
+This will generate quantized models with a suffix of `q` before every param stored in a new directory `QuantizedTFBonsaiModel` inside the model directory.
+One can use this model further on edge devices.
+
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT license.
diff --git a/tf2.0/examples/Bonsai/bonsai_example.ipynb b/tf2.0/examples/Bonsai/bonsai_example.ipynb
new file mode 100644
index 000000000..1935fd2b9
--- /dev/null
+++ b/tf2.0/examples/Bonsai/bonsai_example.ipynb
@@ -0,0 +1,1135 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Bonsai in Tensorflow\n",
+    "\n",
+    "This is a simple notebook that illustrates the usage of the Tensorflow implementation of Bonsai. We are using the USPS dataset. Please refer to `fetch_usps.py` and run it to download and clean up the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T12:06:06.056404Z",
+     "start_time": "2018-08-15T12:06:05.112969Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
+    "# Licensed under the MIT license.\n",
+    "\n",
+    "import helpermethods\n",
+    "import tensorflow as tf\n",
+    "import numpy as np\n",
+    "import sys\n",
+    "import os\n",
+    "\n",
+    "#Provide the GPU number to be used\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
+    "\n",
+    "#Bonsai imports\n",
+    "from edgeml.trainer.bonsaiTrainer import BonsaiTrainer\n",
+    "from edgeml.graph.bonsai import Bonsai\n",
+    "\n",
+    "# TF 2.0 executes eagerly by default; the graph-based trainer needs it off\n",
+    "tf.compat.v1.disable_eager_execution()\n",
+    "\n",
+    "# Fixing seeds for reproducibility\n",
+    "tf.compat.v1.set_random_seed(42)\n",
+    "np.random.seed(42)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# USPS Data\n",
+    "\n",
+    "It is assumed that the USPS data has already been downloaded and set up with the help of [fetch_usps.py](fetch_usps.py) and is present in the `./usps10` subdirectory."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T12:06:06.104645Z",
+     "start_time": "2018-08-15T12:06:06.058368Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Feature Dimension: 257\n",
+      "Num classes: 10\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Loading and Pre-processing dataset for Bonsai\n",
+    "dataDir = \"usps10/\"\n",
+    "(dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std) = helpermethods.preProcessData(dataDir, isRegression=False)\n",
+    "print(\"Feature Dimension: \", dataDimension)\n",
+    "print(\"Num classes: \", numClasses)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Model Parameters\n",
+    "\n",
+    "Note that Bonsai is designed for the low-memory setting and the best results are obtained when operating in that setting. Use the sparsity, projection dimension and tree depth to vary the model size."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T12:06:06.123318Z",
+     "start_time": "2018-08-15T12:06:06.106847Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "sigma = 1.0 #Sigmoid parameter for tanh\n",
+    "depth = 3 #Depth of Bonsai Tree\n",
+    "projectionDimension = 28 #Lower dimensional space for Bonsai to work on\n",
+    "\n",
+    "#Regularizers for Bonsai Parameters\n",
+    "regZ = 0.0001\n",
+    "regW = 0.001\n",
+    "regV = 0.001\n",
+    "regT = 0.001\n",
+    "\n",
+    "totalEpochs = 100\n",
+    "\n",
+    "learningRate = 0.01\n",
+    "\n",
+    "outFile = None\n",
+    "\n",
+    "#Sparsity for Bonsai Parameters. x => 100*x % are non-zeros\n",
+    "sparZ = 0.2\n",
+    "sparW = 0.3\n",
+    "sparV = 0.3\n",
+    "sparT = 0.62\n",
+    "\n",
+    "batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0]))))\n",
+    "\n",
+    "useMCHLoss = True #only for Multiclass cases. True: Multiclass-Hinge Loss, False: Cross Entropy\n",
\n", + "\n", + "#Bonsai uses one classier for Binary, thus this condition\n", + "if numClasses == 2:\n", + " numClasses = 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Placeholders for Data feeding during training and infernece" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:06:06.220274Z", + "start_time": "2018-08-15T12:06:06.125219Z" + } + }, + "outputs": [], + "source": [ + "X = tf.placeholder(\"float32\", [None, dataDimension])\n", + "Y = tf.placeholder(\"float32\", [None, numClasses])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Creating a directory for current model in the datadirectory using timestamp" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:06:06.264985Z", + "start_time": "2018-08-15T12:06:06.222170Z" + } + }, + "outputs": [], + "source": [ + "currDir = helpermethods.createTimeStampDir(dataDir)\n", + "helpermethods.dumpCommand(sys.argv, currDir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bonsai Graph Object\n", + "\n", + "Instantiating the Bonsai Graph which will be used for training and inference." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:06:06.341168Z", + "start_time": "2018-08-15T12:06:06.266877Z" + } + }, + "outputs": [], + "source": [ + "bonsaiObj = Bonsai(numClasses, dataDimension, projectionDimension, depth, sigma)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bonsai Trainer Object\n", + "\n", + "Instantiating the Bonsai Trainer which will be used for 3 phase training." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:06:07.973584Z", + "start_time": "2018-08-15T12:06:06.342945Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\t-vekusu\\AppData\\Local\\Continuum\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tensorflow\\python\\ops\\gradients_impl.py:100: UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. This may consume a large amount of memory.\n", + " \"Converting sparse IndexedSlices to a dense Tensor of unknown shape. \"\n" + ] + } + ], + "source": [ + "bonsaiTrainer = BonsaiTrainer(bonsaiObj, regW, regT, regV, regZ, sparW, sparT, sparV, sparZ,\n", + " learningRate, X, Y, useMCHLoss, outFile)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Session declaration and variable initialization. \n", + "Interactive Session doesn't clog the entire GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:06:15.577425Z", + "start_time": "2018-08-15T12:06:07.976090Z" + } + }, + "outputs": [], + "source": [ + "sess = tf.InteractiveSession()\n", + "sess.run(tf.global_variables_initializer())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Bonsai Training Routine\n", + "\n", + "The method to to run the 3 phase training, followed by giving out the best early stopping model, accuracy along with saving of the parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T12:07:02.500241Z", + "start_time": "2018-08-15T12:06:15.579618Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 0\n", + "\n", + "******************** Dense Training Phase Started ********************\n", + "\n", + "\n", + "Classification Train Loss: 6.388934433460236\n", + "Training accuracy (Classification): 0.6250000005174015\n", + "Test accuracy 0.726956\n", + "MarginLoss + RegLoss: 1.4466879 + 3.6487768 = 5.0954647\n", + "\n", + "\n", + "Epoch Number: 1\n", + "\n", + "Classification Train Loss: 3.6885906954606376\n", + "Training accuracy (Classification): 0.8623611107468605\n", + "Test accuracy 0.758346\n", + "MarginLoss + RegLoss: 1.0173264 + 2.778634 = 3.7959604\n", + "\n", + "\n", + "Epoch Number: 2\n", + "\n", + "Classification Train Loss: 2.667721450328827\n", + "Training accuracy (Classification): 0.9184722271230485\n", + "Test accuracy 0.7429\n", + "MarginLoss + RegLoss: 0.92546654 + 2.095467 = 3.0209336\n", + "\n", + "\n", + "Epoch Number: 3\n", + "\n", + "Classification Train Loss: 1.9921080254846149\n", + "Training accuracy (Classification): 0.941944446000788\n", + "Test accuracy 0.767314\n", + "MarginLoss + RegLoss: 0.7603649 + 1.5889603 = 2.3493252\n", + "\n", + "\n", + "Epoch Number: 4\n", + "\n", + "Classification Train Loss: 1.5233625107341342\n", + "Training accuracy (Classification): 0.9563888907432556\n", + "Test accuracy 0.791231\n", + "MarginLoss + RegLoss: 0.6496898 + 1.2271981 = 1.8768879\n", + "\n", + "\n", + "Epoch Number: 5\n", + "\n", + "Classification Train Loss: 1.1950715631246567\n", + "Training accuracy (Classification): 0.9650000035762787\n", + "Test accuracy 0.810164\n", + "MarginLoss + RegLoss: 0.54003507 + 0.97295314 = 1.5129882\n", + "\n", + "\n", + "Epoch Number: 6\n", + "\n", + "Classification Train Loss: 0.9672323316335678\n", + "Training accuracy (Classification): 0.968333340353436\n", + "Test accuracy 0.855007\n", + "MarginLoss + RegLoss: 0.44149697 + 0.79325426 = 1.2347512\n", + "\n", + "\n", + "Epoch Number: 7\n", + "\n", + "Classification Train Loss: 0.8014380658666292\n", + "Training accuracy (Classification): 0.9722222313284874\n", + "Test accuracy 0.874938\n", + "MarginLoss + RegLoss: 0.37062877 + 0.6628879 = 1.0335166\n", + "\n", + "\n", + "Epoch Number: 8\n", + "\n", + "Classification Train Loss: 0.684503066043059\n", + "Training accuracy (Classification): 0.976111119820012\n", + "Test accuracy 0.899851\n", + "MarginLoss + RegLoss: 0.3099702 + 0.5688073 = 0.8787775\n", + "\n", + "\n", + "Epoch Number: 9\n", + "\n", + "Classification Train Loss: 0.5987317487597466\n", + "Training accuracy (Classification): 0.9794444565971693\n", + "Test accuracy 0.907324\n", + "MarginLoss + RegLoss: 0.2689218 + 0.49965328 = 0.7685751\n", + "\n", + "\n", + "Epoch Number: 10\n", + "\n", + "Classification Train Loss: 0.5343128165437115\n", + "Training accuracy (Classification): 0.9804166778922081\n", + "Test accuracy 0.9143\n", + "MarginLoss + RegLoss: 0.24538836 + 0.44663915 = 0.6920275\n", + "\n", + "\n", + "Epoch Number: 11\n", + "\n", + "Classification Train Loss: 0.48874612069792217\n", + "Training accuracy (Classification): 0.9801388987236552\n", + "Test accuracy 0.916293\n", + "MarginLoss + RegLoss: 0.23703864 + 0.40629783 = 0.6433365\n", + "\n", + "\n", + "Epoch Number: 12\n", + "\n", + "Classification Train Loss: 0.44733552055226433\n", + "Training accuracy 
(Classification): 0.98097223126226\n", + "Test accuracy 0.918286\n", + "MarginLoss + RegLoss: 0.23851919 + 0.37269312 = 0.6112123\n", + "\n", + "\n", + "Epoch Number: 13\n", + "\n", + "Classification Train Loss: 0.4165669356783231\n", + "Training accuracy (Classification): 0.9822222317258517\n", + "Test accuracy 0.917289\n", + "MarginLoss + RegLoss: 0.23061273 + 0.345445 = 0.57605773\n", + "\n", + "\n", + "Epoch Number: 14\n", + "\n", + "Classification Train Loss: 0.39181090601616436\n", + "Training accuracy (Classification): 0.9812500087751282\n", + "Test accuracy 0.92277\n", + "MarginLoss + RegLoss: 0.2121576 + 0.32245666 = 0.53461426\n", + "\n", + "\n", + "Epoch Number: 15\n", + "\n", + "Classification Train Loss: 0.36949437111616135\n", + "Training accuracy (Classification): 0.9820833446251022\n", + "Test accuracy 0.926258\n", + "MarginLoss + RegLoss: 0.19854721 + 0.30341443 = 0.50196165\n", + "\n", + "\n", + "Epoch Number: 16\n", + "\n", + "Classification Train Loss: 0.3469446731938256\n", + "Training accuracy (Classification): 0.9831944538487328\n", + "Test accuracy 0.927255\n", + "MarginLoss + RegLoss: 0.19628116 + 0.28535655 = 0.48163772\n", + "\n", + "\n", + "Epoch Number: 17\n", + "\n", + "Classification Train Loss: 0.329777576857143\n", + "Training accuracy (Classification): 0.984166675971614\n", + "Test accuracy 0.92277\n", + "MarginLoss + RegLoss: 0.20166817 + 0.26965213 = 0.4713203\n", + "\n", + "\n", + "Epoch Number: 18\n", + "\n", + "Classification Train Loss: 0.317672994815641\n", + "Training accuracy (Classification): 0.9815277879436811\n", + "Test accuracy 0.925262\n", + "MarginLoss + RegLoss: 0.20086277 + 0.2559616 = 0.45682436\n", + "\n", + "\n", + "Epoch Number: 19\n", + "\n", + "Classification Train Loss: 0.3000084459781647\n", + "Training accuracy (Classification): 0.9843055655558904\n", + "Test accuracy 0.931739\n", + "MarginLoss + RegLoss: 0.18073215 + 0.24324338 = 0.42397553\n", + "\n", + "\n", + "Epoch Number: 20\n", + "\n", + "Classification Train Loss: 0.2897499371320009\n", + "Training accuracy (Classification): 0.9827777867515882\n", + "Test accuracy 0.921276\n", + "MarginLoss + RegLoss: 0.20172484 + 0.23221089 = 0.43393573\n", + "\n", + "\n", + "Epoch Number: 21\n", + "\n", + "Classification Train Loss: 0.2821065636558665\n", + "Training accuracy (Classification): 0.9812500096029706\n", + "Test accuracy 0.928749\n", + "MarginLoss + RegLoss: 0.18990344 + 0.22147894 = 0.41138238\n", + "\n", + "\n", + "Epoch Number: 22\n", + "\n", + "Classification Train Loss: 0.2660716378854381\n", + "Training accuracy (Classification): 0.9844444559680091\n", + "Test accuracy 0.928251\n", + "MarginLoss + RegLoss: 0.17955597 + 0.21111046 = 0.39066643\n", + "\n", + "\n", + "Epoch Number: 23\n", + "\n", + "Classification Train Loss: 0.2567368100086848\n", + "Training accuracy (Classification): 0.9852777885066138\n", + "Test accuracy 0.928251\n", + "MarginLoss + RegLoss: 0.18770447 + 0.20248988 = 0.39019436\n", + "\n", + "\n", + "Epoch Number: 24\n", + "\n", + "Classification Train Loss: 0.25224825532899964\n", + "Training accuracy (Classification): 0.9823611204822859\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.18552671 + 0.19460817 = 0.38013488\n", + "\n", + "\n", + "Epoch Number: 25\n", + "\n", + "Classification Train Loss: 0.24661735258996487\n", + "Training accuracy (Classification): 0.9804166762365235\n", + "Test accuracy 0.931241\n", + "MarginLoss + RegLoss: 0.18796808 + 0.18610859 = 0.37407666\n", + "\n", + "\n", + "Epoch Number: 26\n", + "\n", + 
"Classification Train Loss: 0.23342499737110403\n", + "Training accuracy (Classification): 0.9829166763358645\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.17906994 + 0.17793566 = 0.3570056\n", + "\n", + "\n", + "Epoch Number: 27\n", + "\n", + "Classification Train Loss: 0.22210048822065195\n", + "Training accuracy (Classification): 0.9851388972666528\n", + "Test accuracy 0.934728\n", + "MarginLoss + RegLoss: 0.17679122 + 0.16876754 = 0.34555876\n", + "\n", + "\n", + "Epoch Number: 28\n", + "\n", + "Classification Train Loss: 0.2189549288402001\n", + "Training accuracy (Classification): 0.9831944538487328\n", + "Test accuracy 0.932237\n", + "MarginLoss + RegLoss: 0.19115414 + 0.16296963 = 0.35412377\n", + "\n", + "\n", + "Epoch Number: 29\n", + "\n", + "Classification Train Loss: 0.21842483865718046\n", + "Training accuracy (Classification): 0.9805555658208\n", + "Test accuracy 0.936722\n", + "MarginLoss + RegLoss: 0.17462157 + 0.15921564 = 0.3338372\n", + "\n", + "\n", + "Epoch Number: 30\n", + "\n", + "Classification Train Loss: 0.21449942576388517\n", + "Training accuracy (Classification): 0.9804166754086813\n", + "Test accuracy 0.939711\n", + "MarginLoss + RegLoss: 0.17741902 + 0.15273981 = 0.33015883\n", + "\n", + "\n", + "Epoch Number: 31\n", + "\n", + "Classification Train Loss: 0.20739994280868107\n", + "Training accuracy (Classification): 0.9825000100665622\n", + "Test accuracy 0.933732\n", + "MarginLoss + RegLoss: 0.17381513 + 0.1498537 = 0.32366884\n", + "\n", + "\n", + "Epoch Number: 32\n", + "\n", + "Classification Train Loss: 0.20110303929282558\n", + "Training accuracy (Classification): 0.9840277888708644\n", + "Test accuracy 0.93423\n", + "MarginLoss + RegLoss: 0.18619148 + 0.14583017 = 0.33202165\n", + "\n", + "\n", + "Epoch Number: 33\n", + "\n", + "******************** IHT Phase Started ********************\n", + "\n", + "\n", + "Classification Train Loss: 0.21433907147083017\n", + "Training accuracy (Classification): 0.9801388987236552\n", + "Test accuracy 0.927255\n", + "MarginLoss + RegLoss: 0.19979775 + 0.12088289 = 0.32068065\n", + "\n", + "\n", + "Epoch Number: 34\n", + "\n", + "Classification Train Loss: 0.1990115779141585\n", + "Training accuracy (Classification): 0.980694454577234\n", + "Test accuracy 0.933234\n", + "MarginLoss + RegLoss: 0.17835513 + 0.12438774 = 0.30274287\n", + "\n", + "\n", + "Epoch Number: 35\n", + "\n", + "Classification Train Loss: 0.20429682172834873\n", + "Training accuracy (Classification): 0.9788888974322213\n", + "Test accuracy 0.929248\n", + "MarginLoss + RegLoss: 0.19013074 + 0.12853864 = 0.31866938\n", + "\n", + "\n", + "Epoch Number: 36\n", + "\n", + "Classification Train Loss: 0.19357945707937083\n", + "Training accuracy (Classification): 0.9816666767001152\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.18534705 + 0.12509713 = 0.31044418\n", + "\n", + "\n", + "Epoch Number: 37\n", + "\n", + "Classification Train Loss: 0.18653404754069117\n", + "Training accuracy (Classification): 0.9818055638008647\n", + "Test accuracy 0.929746\n", + "MarginLoss + RegLoss: 0.18708317 + 0.12236847 = 0.30945164\n", + "\n", + "\n", + "Epoch Number: 38\n", + "\n", + "Classification Train Loss: 0.18141362298693922\n", + "Training accuracy (Classification): 0.9815277871158388\n", + "Test accuracy 0.933234\n", + "MarginLoss + RegLoss: 0.18262453 + 0.11991154 = 0.30253607\n", + "\n", + "\n", + "Epoch Number: 39\n", + "\n", + "Classification Train Loss: 0.17729416727605793\n", + "Training accuracy (Classification): 
0.9820833429694176\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.1798804 + 0.11748926 = 0.29736966\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 40\n", + "\n", + "Classification Train Loss: 0.17237282171845436\n", + "Training accuracy (Classification): 0.9837500088744693\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.17473482 + 0.11479883 = 0.28953364\n", + "\n", + "\n", + "Epoch Number: 41\n", + "\n", + "Classification Train Loss: 0.16901198805620274\n", + "Training accuracy (Classification): 0.9837500097023116\n", + "Test accuracy 0.93423\n", + "MarginLoss + RegLoss: 0.17860568 + 0.112817116 = 0.29142278\n", + "\n", + "\n", + "Epoch Number: 42\n", + "\n", + "Classification Train Loss: 0.16710670509686074\n", + "Training accuracy (Classification): 0.9833333442608515\n", + "Test accuracy 0.936722\n", + "MarginLoss + RegLoss: 0.17501548 + 0.11118551 = 0.286201\n", + "\n", + "\n", + "Epoch Number: 43\n", + "\n", + "Classification Train Loss: 0.16463725310232905\n", + "Training accuracy (Classification): 0.9836111209458775\n", + "Test accuracy 0.93423\n", + "MarginLoss + RegLoss: 0.17687047 + 0.10897398 = 0.28584445\n", + "\n", + "\n", + "Epoch Number: 44\n", + "\n", + "Classification Train Loss: 0.16215091271118987\n", + "Training accuracy (Classification): 0.9843055663837327\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.17832607 + 0.107886344 = 0.2862124\n", + "\n", + "\n", + "Epoch Number: 45\n", + "\n", + "Classification Train Loss: 0.16012930932144323\n", + "Training accuracy (Classification): 0.9841666767994562\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.17309293 + 0.10644325 = 0.2795362\n", + "\n", + "\n", + "Epoch Number: 46\n", + "\n", + "Classification Train Loss: 0.1574974125251174\n", + "Training accuracy (Classification): 0.9850000101659033\n", + "Test accuracy 0.93722\n", + "MarginLoss + RegLoss: 0.17099261 + 0.10526536 = 0.27625796\n", + "\n", + "\n", + "Epoch Number: 47\n", + "\n", + "Classification Train Loss: 0.15617641361637247\n", + "Training accuracy (Classification): 0.9856944539480739\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.16866577 + 0.104043506 = 0.27270928\n", + "\n", + "\n", + "Epoch Number: 48\n", + "\n", + "Classification Train Loss: 0.15530151346077523\n", + "Training accuracy (Classification): 0.9838889001144303\n", + "Test accuracy 0.940209\n", + "MarginLoss + RegLoss: 0.16514857 + 0.10232182 = 0.2674704\n", + "\n", + "\n", + "Epoch Number: 49\n", + "\n", + "Classification Train Loss: 0.15294318615148464\n", + "Training accuracy (Classification): 0.9862500089738104\n", + "Test accuracy 0.939711\n", + "MarginLoss + RegLoss: 0.16788226 + 0.10096101 = 0.26884326\n", + "\n", + "\n", + "Epoch Number: 50\n", + "\n", + "Classification Train Loss: 0.15095406781054205\n", + "Training accuracy (Classification): 0.9861111202173762\n", + "Test accuracy 0.940209\n", + "MarginLoss + RegLoss: 0.17100953 + 0.10046519 = 0.27147472\n", + "\n", + "\n", + "Epoch Number: 51\n", + "\n", + "Classification Train Loss: 0.1513558304351237\n", + "Training accuracy (Classification): 0.9844444543123245\n", + "Test accuracy 0.941704\n", + "MarginLoss + RegLoss: 0.1662268 + 0.100100346 = 0.26632714\n", + "\n", + "\n", + "Epoch Number: 52\n", + "\n", + "Classification Train Loss: 0.14914156941490042\n", + "Training accuracy (Classification): 0.9852777876787715\n", + "Test accuracy 0.941206\n", + "MarginLoss + RegLoss: 0.16318396 + 0.099286705 = 
0.26247066\n", + "\n", + "\n", + "Epoch Number: 53\n", + "\n", + "Classification Train Loss: 0.1497938595712185\n", + "Training accuracy (Classification): 0.9851388997501798\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.17166732 + 0.09957267 = 0.27124\n", + "\n", + "\n", + "Epoch Number: 54\n", + "\n", + "Classification Train Loss: 0.15218847369154295\n", + "Training accuracy (Classification): 0.985277786023087\n", + "Test accuracy 0.938715\n", + "MarginLoss + RegLoss: 0.17181182 + 0.09915227 = 0.2709641\n", + "\n", + "\n", + "Epoch Number: 55\n", + "\n", + "Classification Train Loss: 0.14960632245573732\n", + "Training accuracy (Classification): 0.9855555668473244\n", + "Test accuracy 0.943697\n", + "MarginLoss + RegLoss: 0.16333821 + 0.09872535 = 0.26206356\n", + "\n", + "\n", + "Epoch Number: 56\n", + "\n", + "Classification Train Loss: 0.15064662312053972\n", + "Training accuracy (Classification): 0.9852777885066138\n", + "Test accuracy 0.942202\n", + "MarginLoss + RegLoss: 0.16303498 + 0.09878391 = 0.2618189\n", + "\n", + "\n", + "Epoch Number: 57\n", + "\n", + "Classification Train Loss: 0.15265570394694805\n", + "Training accuracy (Classification): 0.9831944555044174\n", + "Test accuracy 0.940708\n", + "MarginLoss + RegLoss: 0.16671813 + 0.09886683 = 0.26558495\n", + "\n", + "\n", + "Epoch Number: 58\n", + "\n", + "Classification Train Loss: 0.15230748295370075\n", + "Training accuracy (Classification): 0.984166675971614\n", + "Test accuracy 0.938715\n", + "MarginLoss + RegLoss: 0.16594657 + 0.097650595 = 0.26359716\n", + "\n", + "\n", + "Epoch Number: 59\n", + "\n", + "Classification Train Loss: 0.1514456778143843\n", + "Training accuracy (Classification): 0.9843055647280481\n", + "Test accuracy 0.938216\n", + "MarginLoss + RegLoss: 0.16204405 + 0.09645542 = 0.25849947\n", + "\n", + "\n", + "Epoch Number: 60\n", + "\n", + "Classification Train Loss: 0.15362831794967255\n", + "Training accuracy (Classification): 0.9829166771637069\n", + "Test accuracy 0.933732\n", + "MarginLoss + RegLoss: 0.17626402 + 0.09787459 = 0.2741386\n", + "\n", + "\n", + "Epoch Number: 61\n", + "\n", + "Classification Train Loss: 0.15526858448154396\n", + "Training accuracy (Classification): 0.9813889024986161\n", + "Test accuracy 0.933732\n", + "MarginLoss + RegLoss: 0.17297557 + 0.09806729 = 0.27104285\n", + "\n", + "\n", + "Epoch Number: 62\n", + "\n", + "Classification Train Loss: 0.1579084157322844\n", + "Training accuracy (Classification): 0.9816666767001152\n", + "Test accuracy 0.936223\n", + "MarginLoss + RegLoss: 0.17195764 + 0.098572396 = 0.27053005\n", + "\n", + "\n", + "Epoch Number: 63\n", + "\n", + "Classification Train Loss: 0.1566090847675999\n", + "Training accuracy (Classification): 0.9826389013065232\n", + "Test accuracy 0.93423\n", + "MarginLoss + RegLoss: 0.17155647 + 0.10033124 = 0.27188772\n", + "\n", + "\n", + "Epoch Number: 64\n", + "\n", + "Classification Train Loss: 0.1548497351921267\n", + "Training accuracy (Classification): 0.9837500105301539\n", + "Test accuracy 0.941704\n", + "MarginLoss + RegLoss: 0.16137016 + 0.099378176 = 0.26074833\n", + "\n", + "\n", + "Epoch Number: 65\n", + "\n", + "Classification Train Loss: 0.15319975931197405\n", + "Training accuracy (Classification): 0.9829166746801801\n", + "Test accuracy 0.939213\n", + "MarginLoss + RegLoss: 0.16549328 + 0.09872568 = 0.26421896\n", + "\n", + "\n", + "Epoch Number: 66\n", + "\n", + "Classification Train Loss: 0.1565150058724814\n", + "Training accuracy (Classification): 0.9819444542129835\n", + "Test 
accuracy 0.935725\n", + "MarginLoss + RegLoss: 0.17288828 + 0.09988601 = 0.27277428\n", + "\n", + "\n", + "Epoch Number: 67\n", + "\n", + "******************** Sparse Retraining Phase Started ********************\n", + "\n", + "\n", + "Classification Train Loss: 0.15831943404757315\n", + "Training accuracy (Classification): 0.9829166779915491\n", + "Test accuracy 0.935725\n", + "MarginLoss + RegLoss: 0.17936754 + 0.101812266 = 0.28117982\n", + "\n", + "\n", + "Epoch Number: 68\n", + "\n", + "Classification Train Loss: 0.15614786164628136\n", + "Training accuracy (Classification): 0.9838889009422727\n", + "Test accuracy 0.931739\n", + "MarginLoss + RegLoss: 0.17960551 + 0.101831324 = 0.28143683\n", + "\n", + "\n", + "Epoch Number: 69\n", + "\n", + "Classification Train Loss: 0.1662438316270709\n", + "Training accuracy (Classification): 0.9827777884072728\n", + "Test accuracy 0.931739\n", + "MarginLoss + RegLoss: 0.19018382 + 0.10729199 = 0.2974758\n", + "\n", + "\n", + "Epoch Number: 70\n", + "\n", + "Classification Train Loss: 0.16005917576452097\n", + "Training accuracy (Classification): 0.9844444518287977\n", + "Test accuracy 0.929248\n", + "MarginLoss + RegLoss: 0.19133526 + 0.10547125 = 0.2968065\n", + "\n", + "\n", + "Epoch Number: 71\n", + "\n", + "Classification Train Loss: 0.15785305326183638\n", + "Training accuracy (Classification): 0.985000009338061\n", + "Test accuracy 0.933732\n", + "MarginLoss + RegLoss: 0.18749763 + 0.10477199 = 0.29226962\n", + "\n", + "\n", + "Epoch Number: 72\n", + "\n", + "Classification Train Loss: 0.15456503671076563\n", + "Training accuracy (Classification): 0.9843055663837327\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.1811654 + 0.10317116 = 0.28433657\n", + "\n", + "\n", + "Epoch Number: 73\n", + "\n", + "Classification Train Loss: 0.15287091862410307\n", + "Training accuracy (Classification): 0.9848611205816269\n", + "Test accuracy 0.934728\n", + "MarginLoss + RegLoss: 0.17708676 + 0.101716325 = 0.27880308\n", + "\n", + "\n", + "Epoch Number: 74\n", + "\n", + "Classification Train Loss: 0.15090375486761332\n", + "Training accuracy (Classification): 0.9855555643637975\n", + "Test accuracy 0.934728\n", + "MarginLoss + RegLoss: 0.17898533 + 0.10174509 = 0.28073043\n", + "\n", + "\n", + "Epoch Number: 75\n", + "\n", + "Classification Train Loss: 0.15054931139780414\n", + "Training accuracy (Classification): 0.9848611197537847\n", + "Test accuracy 0.93722\n", + "MarginLoss + RegLoss: 0.17272809 + 0.101017065 = 0.27374515\n", + "\n", + "\n", + "Epoch Number: 76\n", + "\n", + "Classification Train Loss: 0.14770951929191747\n", + "Training accuracy (Classification): 0.9855555651916398\n", + "Test accuracy 0.936722\n", + "MarginLoss + RegLoss: 0.17685911 + 0.09888628 = 0.2757454\n", + "\n", + "\n", + "Epoch Number: 77\n", + "\n", + "Classification Train Loss: 0.14727520239022043\n", + "Training accuracy (Classification): 0.9841666767994562\n", + "Test accuracy 0.935725\n", + "MarginLoss + RegLoss: 0.1720485 + 0.09774725 = 0.26979575\n", + "\n", + "\n", + "Epoch Number: 78\n", + "\n", + "Classification Train Loss: 0.1471475510754519\n", + "Training accuracy (Classification): 0.9858333418766657\n", + "Test accuracy 0.940209\n", + "MarginLoss + RegLoss: 0.16558117 + 0.09803399 = 0.26361516\n", + "\n", + "\n", + "Epoch Number: 79\n", + "\n", + "Classification Train Loss: 0.14565238232413927\n", + "Training accuracy (Classification): 0.9861111210452186\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.17031503 + 0.09688788 = 
0.2672029\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 80\n", + "\n", + "Classification Train Loss: 0.14349345521380505\n", + "Training accuracy (Classification): 0.9861111185616918\n", + "Test accuracy 0.941206\n", + "MarginLoss + RegLoss: 0.16280341 + 0.09526416 = 0.25806758\n", + "\n", + "\n", + "Epoch Number: 81\n", + "\n", + "Classification Train Loss: 0.14298133655554718\n", + "Training accuracy (Classification): 0.9848611205816269\n", + "Test accuracy 0.935725\n", + "MarginLoss + RegLoss: 0.16992427 + 0.095204785 = 0.26512906\n", + "\n", + "\n", + "Epoch Number: 82\n", + "\n", + "Classification Train Loss: 0.1410345918395453\n", + "Training accuracy (Classification): 0.9854166756073633\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.16711517 + 0.09361006 = 0.26072523\n", + "\n", + "\n", + "Epoch Number: 83\n", + "\n", + "Classification Train Loss: 0.14173460192978382\n", + "Training accuracy (Classification): 0.9858333418766657\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.17255607 + 0.09335034 = 0.26590642\n", + "\n", + "\n", + "Epoch Number: 84\n", + "\n", + "Classification Train Loss: 0.1413275660533044\n", + "Training accuracy (Classification): 0.985000009338061\n", + "Test accuracy 0.939213\n", + "MarginLoss + RegLoss: 0.1691187 + 0.09220875 = 0.26132745\n", + "\n", + "\n", + "Epoch Number: 85\n", + "\n", + "Classification Train Loss: 0.1399904629215598\n", + "Training accuracy (Classification): 0.9863888977302445\n", + "Test accuracy 0.937718\n", + "MarginLoss + RegLoss: 0.16878359 + 0.09304918 = 0.26183277\n", + "\n", + "\n", + "Epoch Number: 86\n", + "\n", + "Classification Train Loss: 0.14306676108390093\n", + "Training accuracy (Classification): 0.9848611214094691\n", + "Test accuracy 0.933732\n", + "MarginLoss + RegLoss: 0.17234829 + 0.09307802 = 0.2654263\n", + "\n", + "\n", + "Epoch Number: 87\n", + "\n", + "Classification Train Loss: 0.14483444765210152\n", + "Training accuracy (Classification): 0.9838888976309035\n", + "Test accuracy 0.932237\n", + "MarginLoss + RegLoss: 0.17103034 + 0.093002975 = 0.26403332\n", + "\n", + "\n", + "Epoch Number: 88\n", + "\n", + "Classification Train Loss: 0.1426364007509417\n", + "Training accuracy (Classification): 0.9854166772630479\n", + "Test accuracy 0.938216\n", + "MarginLoss + RegLoss: 0.17191838 + 0.09332408 = 0.26524246\n", + "\n", + "\n", + "Epoch Number: 89\n", + "\n", + "Classification Train Loss: 0.1419605797984534\n", + "Training accuracy (Classification): 0.9854166756073633\n", + "Test accuracy 0.93722\n", + "MarginLoss + RegLoss: 0.16863512 + 0.09229554 = 0.26093066\n", + "\n", + "\n", + "Epoch Number: 90\n", + "\n", + "Classification Train Loss: 0.1416015759524372\n", + "Training accuracy (Classification): 0.984166675971614\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.17089692 + 0.0915688 = 0.26246572\n", + "\n", + "\n", + "Epoch Number: 91\n", + "\n", + "Classification Train Loss: 0.1449494053506189\n", + "Training accuracy (Classification): 0.9843055663837327\n", + "Test accuracy 0.933234\n", + "MarginLoss + RegLoss: 0.17210826 + 0.092280895 = 0.26438916\n", + "\n", + "\n", + "Epoch Number: 92\n", + "\n", + "Classification Train Loss: 0.14661915486471522\n", + "Training accuracy (Classification): 0.9826388971673118\n", + "Test accuracy 0.935725\n", + "MarginLoss + RegLoss: 0.17449446 + 0.092357084 = 0.26685154\n", + "\n", + "\n", + "Epoch Number: 93\n", + "\n", + "Classification Train Loss: 
0.1467396484480964\n", + "Training accuracy (Classification): 0.9831944546765752\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.17004617 + 0.09433146 = 0.26437762\n", + "\n", + "\n", + "Epoch Number: 94\n", + "\n", + "Classification Train Loss: 0.1460545692178938\n", + "Training accuracy (Classification): 0.9841666767994562\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.17442052 + 0.09421773 = 0.26863825\n", + "\n", + "\n", + "Epoch Number: 95\n", + "\n", + "Classification Train Loss: 0.14522172489927876\n", + "Training accuracy (Classification): 0.9843055639002058\n", + "Test accuracy 0.936223\n", + "MarginLoss + RegLoss: 0.16918503 + 0.09473272 = 0.26391774\n", + "\n", + "\n", + "Epoch Number: 96\n", + "\n", + "Classification Train Loss: 0.14685245561930868\n", + "Training accuracy (Classification): 0.9838888992865881\n", + "Test accuracy 0.93423\n", + "MarginLoss + RegLoss: 0.1715351 + 0.09685955 = 0.26839465\n", + "\n", + "\n", + "Epoch Number: 97\n", + "\n", + "Classification Train Loss: 0.15079948357823822\n", + "Training accuracy (Classification): 0.9830555634366142\n", + "Test accuracy 0.935227\n", + "MarginLoss + RegLoss: 0.1724481 + 0.0967999 = 0.269248\n", + "\n", + "\n", + "Epoch Number: 98\n", + "\n", + "Classification Train Loss: 0.15230303982065785\n", + "Training accuracy (Classification): 0.9816666767001152\n", + "Test accuracy 0.932237\n", + "MarginLoss + RegLoss: 0.17799449 + 0.09676037 = 0.27475485\n", + "\n", + "\n", + "Epoch Number: 99\n", + "\n", + "Classification Train Loss: 0.1494007593848639\n", + "Training accuracy (Classification): 0.9838888976309035\n", + "Test accuracy 0.932735\n", + "MarginLoss + RegLoss: 0.17286898 + 0.096531555 = 0.26940054\n", + "\n", + "\n", + "Non-Zero : 4156.0 Model Size: 31.703125 KB hasSparse: True\n", + "\n", + "For Classification, Maximum Test accuracy at compressed model size(including early stopping): 0.94369704 at Epoch: 56\n", + "Final Test Accuracy: 0.93273544\n", + "The Model Directory: usps10//TFBonsaiResults/16_20_53_15_02_19\n", + "\n" + ] + } + ], + "source": [ + "bonsaiTrainer.train(batchSize, totalEpochs, sess,\n", + " Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tf2.0/examples/Bonsai/bonsai_example.py b/tf2.0/examples/Bonsai/bonsai_example.py new file mode 100644 index 000000000..2fc29e7c4 --- /dev/null +++ b/tf2.0/examples/Bonsai/bonsai_example.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +import helpermethods +import tensorflow as tf +import numpy as np +import sys +from edgeml.trainer.bonsaiTrainer import BonsaiTrainer +from edgeml.graph.bonsai import Bonsai + +tf.compat.v1.disable_eager_execution() + +def main(): + # Fixing seeds for reproducibility + tf.compat.v1.set_random_seed(42) + np.random.seed(42) + + # Hyper Param pre-processing + args = helpermethods.getArgs() + + # Set 'isRegression' to be True, for regression. Default is 'False'. 
+ isRegression = args.regression + + sigma = args.sigma + depth = args.depth + + projectionDimension = args.proj_dim + regZ = args.rZ + regT = args.rT + regW = args.rW + regV = args.rV + + totalEpochs = args.epochs + + learningRate = args.learning_rate + + dataDir = args.data_dir + + outFile = args.output_file + + (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, + mean, std) = helpermethods.preProcessData(dataDir, isRegression) + + sparZ = args.sZ + + if numClasses > 2: + sparW = 0.2 + sparV = 0.2 + sparT = 0.2 + else: + sparW = 1 + sparV = 1 + sparT = 1 + + if args.sW is not None: + sparW = args.sW + if args.sV is not None: + sparV = args.sV + if args.sT is not None: + sparT = args.sT + + if args.batch_size is None: + batchSize = np.maximum(100, int(np.ceil(np.sqrt(Ytrain.shape[0])))) + else: + batchSize = args.batch_size + + useMCHLoss = True + + if numClasses == 2: + numClasses = 1 + + X = tf.compat.v1.placeholder("float32", [None, dataDimension]) + Y = tf.compat.v1.placeholder("float32", [None, numClasses]) + + currDir = helpermethods.createTimeStampDir(dataDir) + + helpermethods.dumpCommand(sys.argv, currDir) + helpermethods.saveMeanStd(mean, std, currDir) + + # numClasses = 1 for binary case + bonsaiObj = Bonsai(numClasses, dataDimension, + projectionDimension, depth, sigma, isRegression) + + bonsaiTrainer = BonsaiTrainer(bonsaiObj, + regW, regT, regV, regZ, + sparW, sparT, sparV, sparZ, + learningRate, X, Y, useMCHLoss, outFile) + + sess = tf.compat.v1.InteractiveSession() + + sess.run(tf.compat.v1.global_variables_initializer()) + + bonsaiTrainer.train(batchSize, totalEpochs, sess, + Xtrain, Xtest, Ytrain, Ytest, dataDir, currDir) + + sess.close() + sys.stdout.close() + + +if __name__ == '__main__': + main() + +# For the following command: +# Data - Curet +# python2 bonsai_example.py -dir ./curet/ -d 2 -p 22 -rW 0.00001 -rZ 0.0000001 -rV 0.00001 -rT 0.000001 -sZ 0.4 -sW 0.5 -sV 0.5 -sT 1 -e 300 -s 0.1 -b 20 +# Final Output - useMCHLoss = True +# Maximum Test accuracy at compressed model size(including early stopping): 0.93727726 at Epoch: 297 +# Final Test Accuracy: 0.9337135 +# Non-Zeros: 24231.0 Model Size: 115.65625 KB hasSparse: True + +# Data - usps2 +# python2 bonsai_example.py -dir /mnt/c/Users/t-vekusu/Downloads/datasets/usps-binary/ -d 2 -p 22 -rW 0.00001 -rZ 0.0000001 -rV 0.00001 -rT 0.000001 -sZ 0.4 -sW 0.5 -sV 0.5 -sT 1 -e 300 -s 0.1 -b 20 +# Maximum Test accuracy at compressed model size(including early stopping): 0.9521674 at Epoch: 246 +# Final Test Accuracy: 0.94170403 +# Non-Zeros: 2636.0 Model Size: 19.1328125 KB hasSparse: True diff --git a/tf2.0/examples/Bonsai/fetch_usps.py b/tf2.0/examples/Bonsai/fetch_usps.py new file mode 100644 index 000000000..c1b2e0726 --- /dev/null +++ b/tf2.0/examples/Bonsai/fetch_usps.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Setting up the USPS Data. + +import subprocess +import os +import numpy as np +from sklearn.datasets import load_svmlight_file +import sys + +def downloadData(workingDir, downloadDir, linkTrain, linkTest): + def runcommand(command): + p = subprocess.Popen(command.split(), stdout=subprocess.PIPE) + output, error = p.communicate() + assert(p.returncode == 0), 'Command failed: %s' % command + + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + try: + os.mkdir(path) + except OSError: + print("Could not create %s. 
Make sure the path does" % path)
+        print("not already exist and you have permissions to create it.")
+        return False
+    cwd = os.getcwd()
+    os.chdir(path)
+    print("Downloading data")
+    command = 'wget %s' % linkTrain
+    runcommand(command)
+    command = 'wget %s' % linkTest
+    runcommand(command)
+    print("Extracting data")
+    command = 'bzip2 -d usps.bz2'
+    runcommand(command)
+    command = 'bzip2 -d usps.t.bz2'
+    runcommand(command)
+    command = 'mv usps train.txt'
+    runcommand(command)
+    command = 'mv usps.t test.txt'
+    runcommand(command)
+    os.chdir(cwd)
+    return True
+
+if __name__ == '__main__':
+    workingDir = './'
+    downloadDir = 'usps10'
+    linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2'
+    linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2'
+    failureMsg = '''
+Download Failed!
+To manually perform the download
+\t1. Create a new empty directory named `usps10`.
+\t2. Download the data from the following links into the usps10 directory.
+\t\tTest: %s
+\t\tTrain: %s
+\t3. Extract the downloaded files.
+\t4. Rename `usps` to `train.txt`, and
+\t5. Rename `usps.t` to `test.txt`.
+''' % (linkTrain, linkTest)
+
+    if not downloadData(workingDir, downloadDir, linkTrain, linkTest):
+        exit(failureMsg)
+    print("Done")
diff --git a/tf2.0/examples/Bonsai/helpermethods.py b/tf2.0/examples/Bonsai/helpermethods.py
new file mode 100644
index 000000000..febe0613e
--- /dev/null
+++ b/tf2.0/examples/Bonsai/helpermethods.py
@@ -0,0 +1,270 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+
+'''
+    Functions to check sanity of input arguments
+    for the example script.
+'''
+import argparse
+import datetime
+import os
+import numpy as np
+
+
+def checkIntPos(value):
+    ivalue = int(value)
+    if ivalue <= 0:
+        raise argparse.ArgumentTypeError(
+            "%s is an invalid positive int value" % value)
+    return ivalue
+
+
+def checkIntNneg(value):
+    ivalue = int(value)
+    if ivalue < 0:
+        raise argparse.ArgumentTypeError(
+            "%s is an invalid non-neg int value" % value)
+    return ivalue
+
+
+def checkFloatNneg(value):
+    fvalue = float(value)
+    if fvalue < 0:
+        raise argparse.ArgumentTypeError(
+            "%s is an invalid non-neg float value" % value)
+    return fvalue
+
+
+def checkFloatPos(value):
+    fvalue = float(value)
+    if fvalue <= 0:
+        raise argparse.ArgumentTypeError(
+            "%s is an invalid positive float value" % value)
+    return fvalue
+
+
+def str2bool(v):
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def getArgs():
+    '''
+    Function to parse arguments for Bonsai Algorithm
+    '''
+    parser = argparse.ArgumentParser(
+        description='HyperParams for Bonsai Algorithm')
+    parser.add_argument('-dir', '--data-dir', required=True,
+                        help='Data directory containing ' +
+                        'train.npy and test.npy')
+
+    parser.add_argument('-d', '--depth', type=checkIntNneg, default=2,
+                        help='Depth of Bonsai Tree ' +
+                        '(default: 2 try: [0, 1, 3])')
+    parser.add_argument('-p', '--proj-dim', type=checkIntPos, default=10,
+                        help='Projection Dimension ' +
+                        '(default: 10 try: [5, 20, 30])')
+    parser.add_argument('-s', '--sigma', type=float, default=1.0,
+                        help='Parameter for sigmoid sharpness ' +
+                        '(default: 1.0 try: [3.0, 0.05, 0.1])')
+    parser.add_argument('-e', '--epochs', type=checkIntPos, default=42,
+                        help='Total Epochs (default: 42 try: [100, 150, 60])')
+    parser.add_argument('-b', '--batch-size', type=checkIntPos,
+                        help='Batch Size to be used ' +
+                        '(default: max(100, sqrt(train_samples)))')
+    parser.add_argument('-lr', '--learning-rate', type=checkFloatPos,
+                        default=0.01, help='Initial Learning rate for ' +
+                        'Adam Optimizer (default: 0.01)')
+
+    parser.add_argument('-rW', type=float, default=0.0001,
+                        help='Regularizer for predictor parameter W ' +
+                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
+    parser.add_argument('-rV', type=float, default=0.0001,
+                        help='Regularizer for predictor parameter V ' +
+                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
+    parser.add_argument('-rT', type=float, default=0.0001,
+                        help='Regularizer for branching parameter Theta ' +
+                        '(default: 0.0001 try: [0.01, 0.001, 0.00001])')
+    parser.add_argument('-rZ', type=float, default=0.00001,
+                        help='Regularizer for projection parameter Z ' +
+                        '(default: 0.00001 try: [0.001, 0.0001, 0.000001])')
+
+    parser.add_argument('-sW', type=checkFloatPos,
+                        help='Sparsity for predictor parameter W ' +
+                        '(default: For Binary classification 1.0 else 0.2 ' +
+                        'try: [0.1, 0.3, 0.5])')
+    parser.add_argument('-sV', type=checkFloatPos,
+                        help='Sparsity for predictor parameter V ' +
+                        '(default: For Binary classification 1.0 else 0.2 ' +
+                        'try: [0.1, 0.3, 0.5])')
+    parser.add_argument('-sT', type=checkFloatPos,
+                        help='Sparsity for branching parameter Theta ' +
+                        '(default: For Binary classification 1.0 else 0.2 ' +
+                        'try: [0.1, 0.3, 0.5])')
+    parser.add_argument('-sZ', type=checkFloatPos, default=0.2,
+                        help='Sparsity for projection parameter Z ' +
+                        '(default: 0.2 try: [0.1, 0.3, 0.5])')
+    parser.add_argument('-oF', '--output-file', default=None,
+                        help='Output file for dumping the program output ' +
+                        '(default: stdout)')
+
+    parser.add_argument('-regression', type=str2bool, default=False,
+                        help='Boolean argument which controls whether to perform ' +
+                        'regression or classification ' +
+                        '(default: False, i.e. classification; values: [True, False])')
+
+    return parser.parse_args()
+
+
+def getQuantArgs():
+    '''
+    Function to parse arguments for Model Quantisation
+    '''
+    parser = argparse.ArgumentParser(
+        description='Arguments for quantizing Fast models. ' +
+        'Works only for piece-wise linear non-linearities, ' +
+        'like relu, quantTanh, quantSigm (check rnn.py for the definitions)')
+    parser.add_argument('-dir', '--model-dir', required=True,
+                        help='model directory containing ' +
+                        '*.npy weight files dumped from the trained model')
+    parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127,
+                        help='this represents the maximum possible value ' +
+                        'in the model, essentially the byte complexity; ' +
+                        '127 => 1 byte is the default')
+
+    return parser.parse_args()
+
+
+def createTimeStampDir(dataDir):
+    '''
+    Creates a directory with a timestamp as its name
+    '''
+    if os.path.isdir(dataDir + '/TFBonsaiResults') is False:
+        try:
+            os.mkdir(dataDir + '/TFBonsaiResults')
+        except OSError:
+            print("Creation of the directory %s failed" %
+                  (dataDir + '/TFBonsaiResults'))
+
+    currDir = 'TFBonsaiResults/' + datetime.datetime.now().strftime("%H_%M_%S_%d_%m_%y")
+    if os.path.isdir(dataDir + '/' + currDir) is False:
+        try:
+            os.mkdir(dataDir + '/' + currDir)
+        except OSError:
+            print("Creation of the directory %s failed" %
+                  (dataDir + '/' + currDir))
+        else:
+            return (dataDir + '/' + currDir)
+    return None
+
+
+def preProcessData(dataDir, isRegression=False):
+    '''
+    Function to pre-process input data
+    Expects a .npy file of form [lbl feats] for each datapoint
+    Outputs train and test sets with a 1 appended to each datapoint for bias induction
+    dataDimension, numClasses are inferred directly
+    '''
+    train = np.load(dataDir + '/train.npy')
+    test = np.load(dataDir + '/test.npy')
+
+    dataDimension = int(train.shape[1]) - 1
+
+    Xtrain = train[:, 1:dataDimension + 1]
+    Ytrain_ = train[:, 0]
+
+    Xtest = test[:, 1:dataDimension + 1]
+    Ytest_ = test[:, 0]
+
+    # Mean Var Normalisation
+    mean = np.mean(Xtrain, 0)
+    std = np.std(Xtrain, 0)
+    std[std[:] < 0.000001] = 1
+    Xtrain = (Xtrain - mean) / std
+    Xtest = (Xtest - mean) / std
+    # End Mean Var normalisation
+
+    # Classification.
+    if (isRegression == False):
+        numClasses = max(Ytrain_) - min(Ytrain_) + 1
+        numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1))
+
+        lab = Ytrain_.astype('uint8')
+        lab = np.array(lab) - min(lab)
+
+        lab_ = np.zeros((Xtrain.shape[0], numClasses))
+        lab_[np.arange(Xtrain.shape[0]), lab] = 1
+        if (numClasses == 2):
+            Ytrain = np.reshape(lab, [-1, 1])
+        else:
+            Ytrain = lab_
+
+        lab = Ytest_.astype('uint8')
+        lab = np.array(lab) - min(lab)
+
+        lab_ = np.zeros((Xtest.shape[0], numClasses))
+        lab_[np.arange(Xtest.shape[0]), lab] = 1
+        if (numClasses == 2):
+            Ytest = np.reshape(lab, [-1, 1])
+        else:
+            Ytest = lab_
+
+    elif (isRegression == True):
+        # The number of classes is always 1, for regression.
+ numClasses = 1 + Ytrain = Ytrain_ + Ytest = Ytest_ + + trainBias = np.ones([Xtrain.shape[0], 1]) + Xtrain = np.append(Xtrain, trainBias, axis=1) + testBias = np.ones([Xtest.shape[0], 1]) + Xtest = np.append(Xtest, testBias, axis=1) + + mean = np.append(mean, np.array([0])) + std = np.append(std, np.array([1])) + + if (isRegression == False): + return dataDimension + 1, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std + elif (isRegression == True): + return dataDimension + 1, numClasses, Xtrain, Ytrain.reshape((-1, 1)), Xtest, Ytest.reshape((-1, 1)), mean, std + + +def dumpCommand(list, currDir): + ''' + Dumps the current command to a file for further use + ''' + commandFile = open(currDir + '/command.txt', 'w') + command = "python" + + command = command + " " + ' '.join(list) + commandFile.write(command) + + commandFile.flush() + commandFile.close() + + +def saveMeanStd(mean, std, currDir): + ''' + Function to save Mean and Std vectors + ''' + np.save(currDir + '/mean.npy', mean) + np.save(currDir + '/std.npy', std) + saveMeanStdSeeDot(mean, std, currDir + "/SeeDot") + + +def saveMeanStdSeeDot(mean, std, seeDotDir): + ''' + Function to save Mean and Std vectors + ''' + if os.path.isdir(seeDotDir) is False: + try: + os.mkdir(seeDotDir) + except OSError: + print("Creation of the directory %s failed" % + seeDotDir) + np.savetxt(seeDotDir + '/Mean', mean, delimiter="\t") + np.savetxt(seeDotDir + '/Std', std, delimiter="\t") diff --git a/tf2.0/examples/Bonsai/process_usps.py b/tf2.0/examples/Bonsai/process_usps.py new file mode 100644 index 000000000..252ba11e2 --- /dev/null +++ b/tf2.0/examples/Bonsai/process_usps.py @@ -0,0 +1,54 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Processing the USPS Data. It is assumed that the data is already +# downloaded. + +import subprocess +import os +import numpy as np +from sklearn.datasets import load_svmlight_file +import sys + +def processData(workingDir, downloadDir): + def loadLibSVMFile(file): + data = load_svmlight_file(file) + features = data[0] + labels = data[1] + retMat = np.zeros([features.shape[0], features.shape[1] + 1]) + retMat[:, 0] = labels + retMat[:, 1:] = features.todense() + return retMat + + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + trf = path + '/train.txt' + tsf = path + '/test.txt' + assert os.path.isfile(trf), 'File not found: %s' % trf + assert os.path.isfile(tsf), 'File not found: %s' % tsf + train = loadLibSVMFile(trf) + test = loadLibSVMFile(tsf) + + # Convert the labels from 0 to numClasses-1 + y_train = train[:, 0] + y_test = test[:, 0] + + lab = y_train.astype('uint8') + lab = np.array(lab) - min(lab) + train[:, 0] = lab + + lab = y_test.astype('uint8') + lab = np.array(lab) - min(lab) + test[:, 0] = lab + + np.save(path + '/train.npy', train) + np.save(path + '/test.npy', test) + +if __name__ == '__main__': + # Configuration + workingDir = './' + downloadDir = 'usps10' + # End config + print("Processing data") + processData(workingDir, downloadDir) + print("Done") diff --git a/tf2.0/examples/Bonsai/quantizeBonsaiModels.py b/tf2.0/examples/Bonsai/quantizeBonsaiModels.py new file mode 100644 index 000000000..6ff9f737c --- /dev/null +++ b/tf2.0/examples/Bonsai/quantizeBonsaiModels.py @@ -0,0 +1,72 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+
+import helpermethods
+import os
+import numpy as np
+
+
+def min_max(A, name):
+    print(name + " has max: " + str(np.max(A)) + " min: " + str(np.min(A)))
+    return np.max([np.abs(np.max(A)), np.abs(np.min(A))])
+
+
+def quantizeFastModels(modelDir, maxValue=127, scalarScaleFactor=1000):
+    ls = os.listdir(modelDir)
+    paramNameList = []
+    paramWeightList = []
+    paramLimitList = []
+
+    for file in ls:
+        if file.endswith("npy"):
+            if file.startswith("mean") or file.startswith("std") or file.startswith("hyperParam"):
+                continue
+            else:
+                paramNameList.append(file)
+                temp = np.load(modelDir + "/" + file)
+                paramWeightList.append(temp)
+                paramLimitList.append(min_max(temp, file))
+
+    paramLimit = np.max(paramLimitList)
+
+    paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit))
+
+    quantParamWeights = []
+    for param in paramWeightList:
+        temp = np.round(paramScaleFactor * param)
+        temp[temp[:] > maxValue] = maxValue
+        temp[temp[:] < -maxValue] = -1 * (maxValue + 1)
+
+        if maxValue <= 127:
+            temp = temp.astype('int8')
+        elif maxValue <= 32767:
+            temp = temp.astype('int16')
+        else:
+            temp = temp.astype('int32')
+
+        quantParamWeights.append(temp)
+
+    # Define the output directory up front so it is available even when the
+    # directory already exists from a previous run
+    quantModelDir = modelDir + '/QuantizedTFBonsaiModel'
+    if os.path.isdir(quantModelDir) is False:
+        try:
+            os.mkdir(quantModelDir)
+        except OSError:
+            print("Creation of the directory %s failed" % quantModelDir)
+
+    np.save(quantModelDir + "/paramScaleFactor.npy",
+            paramScaleFactor.astype('int32'))
+
+    for i in range(len(paramNameList)):
+        np.save(quantModelDir + "/q" + paramNameList[i], quantParamWeights[i])
+
+    print("\n\nQuantized Model Dir: " + quantModelDir)
+
+
+def main():
+    args = helpermethods.getQuantArgs()
+    quantizeFastModels(args.model_dir, int(args.max_val))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tf2.0/examples/FastCells/README.md b/tf2.0/examples/FastCells/README.md
new file mode 100644
index 000000000..52b12e6b2
--- /dev/null
+++ b/tf2.0/examples/FastCells/README.md
@@ -0,0 +1,77 @@
+# EdgeML FastCells on a sample public dataset
+
+This directory includes an example notebook and a general execution script for
+FastCells (FastRNN & FastGRNN) developed as part of EdgeML, along with modified
+UGRNN, GRU and LSTM cells that support the LSQ training routine.
+Also included are a sample cleanup and use-case on the USPS10 public dataset.
+
+`edgeml.graph.rnn` implements the custom RNN cells of **FastRNN** ([`FastRNNCell`](../../edgeml/graph/rnn.py#L215)) and **FastGRNN** ([`FastGRNNCell`](../../edgeml/graph/rnn.py#L40)) with
+multiple additional features like Low-Rank parameterisation, custom
+non-linearities etc. Similar to Bonsai and ProtoNN, the three-phase training
+routine for FastRNN and FastGRNN is decoupled from the custom cells to
+facilitate plug and play behaviour of the custom RNN cells in other
+architectures (NMT, Encoder-Decoder etc.) in place of the inbuilt `RNNCell`, `GRUCell`, `BasicLSTMCell` etc.
+`edgeml.graph.rnn` also contains modified RNN cells of **UGRNN** ([`UGRNNLRCell`](../../edgeml/graph/rnn.py#L862)),
+**GRU** ([`GRULRCell`](../../edgeml/graph/rnn.py#L635)) and **LSTM** ([`LSTMLRCell`](../../edgeml/graph/rnn.py#L376)). These cells can also be substituted for FastCells wherever feasible.
+
+For training FastCells, `edgeml.trainer.fastTrainer` implements the three-phase
+FastCell training routine in Tensorflow. A simple example,
+`examples/fastcell_example.py`, is provided to illustrate its usage.
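+A minimal sketch of that wiring is shown below. It is only a rough
+illustration, not a substitute for `fastcell_example.py`; the values it uses
+(`inputDims = 16`, `hiddenDims = 32`, `numClasses = 10`, `dataDimension = 256`,
+full-rank and fully dense weight matrices) are taken from the USPS notebook
+settings and are illustrative, not prescriptive:
+
+```python
+import tensorflow as tf
+from edgeml.graph.rnn import FastGRNNCell
+from edgeml.trainer.fastTrainer import FastTrainer
+
+tf.compat.v1.disable_eager_execution()  # this port runs in graph mode
+
+inputDims, hiddenDims, numClasses, dataDimension = 16, 32, 10, 256
+
+# Inputs are [batch, timesteps, inputDims]; labels are one-hot
+X = tf.compat.v1.placeholder("float", [None, int(dataDimension / inputDims), inputDims])
+Y = tf.compat.v1.placeholder("float", [None, numClasses])
+
+# Any of the cells in edgeml.graph.rnn can be dropped in here
+fastCell = FastGRNNCell(hiddenDims, gate_non_linearity="sigmoid",
+                        update_non_linearity="tanh", wRank=None, uRank=None)
+trainer = FastTrainer(fastCell, X, Y, sW=1.0, sU=1.0,
+                      learningRate=0.01, outFile=None)
+
+sess = tf.compat.v1.InteractiveSession()
+sess.run(tf.compat.v1.global_variables_initializer())
+# trainer.train(...) then runs the three-phase routine; see
+# fastcell_example.py and the notebook for the exact call and data arguments.
+```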
+
+Note that `fastcell_example.py` assumes that data is in a specific format. It
+is assumed that train and test data are contained in two files, `train.npy` and
+`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples,
+numberOfFeatures]`. numberOfFeatures is `timesteps x inputDims`, flattened
+across the timestep dimension: the input of the 1st timestep, followed by the
+2nd, and so on. For an N-Class problem, we assume the labels are integers from 0
+through N-1. Lastly, the training data, `train.npy`, is assumed to be well
+shuffled, as the training routine doesn't shuffle internally.
+
+**Tested With:** Tensorflow 2.0 (with eager execution disabled) on Python 3.5
+
+## Download and clean up sample dataset
+
+We validate the code using the USPS dataset. The download and cleanup of the
+dataset to match the above-mentioned format is done by the scripts
+[fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py):
+
+```
+python fetch_usps.py
+python process_usps.py
+```
+
+
+## Sample command for FastCells on USPS10
+The following sample run on usps10 should validate your library:
+
+Note: Even though usps10 is not a time-series dataset, it can be treated as a time-series in which each row of an image arrives at a single time step.
+So the number of timesteps = 16 and inputDims = 16
+
+```bash
+python fastcell_example.py -dir usps10/ -id 16 -hd 32
+```
+This command should give you a final output screen which reads roughly similar to the following (the numbers might not match exactly due to version mismatches):
+
+```
+Maximum Test accuracy at compressed model size(including early stopping): 0.9407075 at Epoch: 262
+Final Test Accuracy: 0.93721974
+
+Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False
+```
+The `usps10/` directory will now have a consolidated results file called `FastRNNResults.txt` or `FastGRNNResults.txt`, depending on the choice of the RNN cell.
+A directory `FastRNNResults` or `FastGRNNResults` will also be created, containing the corresponding models from each run of the code on the `usps10` dataset.
+
+## Byte Quantization (Q) for model compression
+If you wish to quantize the generated model to byte quantized integers, use `quantizeFastModels.py`. Usage instructions:
+
+```
+python quantizeFastModels.py -h
+```
+
+This will generate quantized models, with a `q` prefixed to every param file name, stored in a new directory `QuantizedFastModel` inside the model directory.
+One can use these models further on edge devices.
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT license.
diff --git a/tf2.0/examples/FastCells/fastcell_example.ipynb b/tf2.0/examples/FastCells/fastcell_example.ipynb
new file mode 100644
index 000000000..d1d59ee80
--- /dev/null
+++ b/tf2.0/examples/FastCells/fastcell_example.ipynb
@@ -0,0 +1,1557 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# FastRNN and FastGRNN in Tensorflow\n",
+    "\n",
+    "This is a simple notebook that illustrates the usage of the Tensorflow implementation of FastRNN and FastGRNN. We are using the USPS dataset. Please refer to `fetch_usps.py` and run it for downloading and cleaning up the dataset."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Copyright (c) Microsoft Corporation. All rights reserved.\n",
+    "# Licensed under the MIT license.\n",
+    "\n",
+    "import helpermethods\n",
+    "import tensorflow as tf\n",
+    "import numpy as np\n",
+    "import sys\n",
+    "import os\n",
+    "\n",
+    "#Provide the GPU number to be used\n",
+    "os.environ['CUDA_VISIBLE_DEVICES'] =''\n",
+    "\n",
+    "#FastRNN and FastGRNN imports\n",
+    "from edgeml.trainer.fastTrainer import FastTrainer\n",
+    "from edgeml.graph.rnn import FastGRNNCell\n",
+    "from edgeml.graph.rnn import FastRNNCell\n",
+    "from edgeml.graph.rnn import UGRNNLRCell\n",
+    "from edgeml.graph.rnn import GRULRCell\n",
+    "from edgeml.graph.rnn import LSTMLRCell\n",
+    "\n",
+    "# This TF 2.0 port runs in graph mode, so eager execution is disabled\n",
+    "tf.compat.v1.disable_eager_execution()\n",
+    "\n",
+    "# Fixing seeds for reproducibility\n",
+    "tf.compat.v1.set_random_seed(42)\n",
+    "np.random.seed(42)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# USPS Data\n",
+    "\n",
+    "It is assumed that the USPS data has already been downloaded and processed with [fetch_usps.py](fetch_usps.py) and [process_usps.py](process_usps.py), and is present in the `./usps10` subdirectory.\n",
+    "\n",
+    "Note: Even though usps10 is not a time-series dataset, it can be treated as a time-series in which each row of an image arrives at a single time step.\n",
+    "So the number of timesteps = 16 and inputDims = 16"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Feature Dimension:  256\n",
+      "Num classes:  10\n"
+     ]
+    }
+   ],
+   "source": [
+    "#Loading and Pre-processing dataset for FastCells\n",
+    "dataDir = \"usps10\"\n",
+    "(dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std) = helpermethods.preProcessData(dataDir)\n",
+    "print(\"Feature Dimension: \", dataDimension)\n",
+    "print(\"Num classes: \", numClasses)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Model Parameters\n",
+    "\n",
+    "FastRNN and FastGRNN work for most of the hyper-parameters with which you could achieve decent accuracies with LSTM/GRU. Over and above that, you can use low-rank, sparsity and quantization to reduce the model size by up to 45x when compared to LSTM/GRU."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "cell = \"FastGRNN\" # Choose between FastGRNN, FastRNN, UGRNN, GRU and LSTM\n",
+    "\n",
+    "inputDims = 16 #features taken in by RNN in one timestep\n",
+    "hiddenDims = 32 #hidden state of RNN\n",
+    "\n",
+    "totalEpochs = 300\n",
+    "batchSize = 100\n",
+    "\n",
+    "learningRate = 0.01\n",
+    "decayStep = 200\n",
+    "decayRate = 0.1\n",
+    "\n",
+    "outFile = None #provide your file, if you need all the logging info in a file\n",
+    "\n",
+    "#low-rank parameterisation for weight matrices. None => Full Rank\n",
+    "wRank = None \n",
+    "uRank = None \n",
+    "\n",
+    "#Sparsity of the weight matrices. x => 100*x % are non-zeros\n",
+    "sW = 1.0 \n",
+    "sU = 1.0\n",
+    "\n",
+    "#Non-linearities for the RNN architecture. Can choose from \"tanh, sigmoid, relu, quantTanh, quantSigm\"\n",
+    "update_non_linearity = \"tanh\"\n",
+    "gate_non_linearity = \"sigmoid\"\n",
+    "\n",
+    "assert dataDimension % inputDims == 0, \"Infeasible per step input, the number of timesteps has to be an integer\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Placeholders for data feeding during training and inference"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X = tf.compat.v1.placeholder(\"float\", [None, int(dataDimension / inputDims), inputDims])\n",
+    "Y = tf.compat.v1.placeholder(\"float\", [None, numClasses])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Creating a directory for the current model in the data directory using a timestamp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "currDir = helpermethods.createTimeStampDir(dataDir, cell)\n",
+    "helpermethods.dumpCommand(sys.argv, currDir)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# FastCell Graph Object\n",
+    "\n",
+    "Instantiating the FastCell Graph using modular RNN Cells which will be used for training and inference.\n",
+    "\n",
+    "Note: RNN cells in edgeml.graph.rnn can be used anywhere in place of LSTM/GRU in a plug & play fashion."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#Create appropriate RNN cell object based on choice\n",
+    "if cell == \"FastGRNN\":\n",
+    "    FastCell = FastGRNNCell(hiddenDims, gate_non_linearity=gate_non_linearity,\n",
+    "                            update_non_linearity=update_non_linearity,\n",
+    "                            wRank=wRank, uRank=uRank)\n",
+    "elif cell == \"FastRNN\":\n",
+    "    FastCell = FastRNNCell(hiddenDims, update_non_linearity=update_non_linearity,\n",
+    "                           wRank=wRank, uRank=uRank)\n",
+    "elif cell == \"UGRNN\":\n",
+    "    FastCell = UGRNNLRCell(hiddenDims, update_non_linearity=update_non_linearity,\n",
+    "                           wRank=wRank, uRank=uRank)\n",
+    "elif cell == \"GRU\":\n",
+    "    FastCell = GRULRCell(hiddenDims, update_non_linearity=update_non_linearity,\n",
+    "                         wRank=wRank, uRank=uRank)\n",
+    "elif cell == \"LSTM\":\n",
+    "    FastCell = LSTMLRCell(hiddenDims, update_non_linearity=update_non_linearity,\n",
+    "                          wRank=wRank, uRank=uRank)\n",
+    "else:\n",
+    "    sys.exit('Exiting: No Such Cell as ' + cell)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# FastCell Trainer Object\n",
+    "\n",
+    "Instantiating the FastCell Trainer which will be used for the 3 phase training"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "FastCellTrainer = FastTrainer(FastCell, X, Y, sW=sW, sU=sU, learningRate=learningRate, outFile=outFile)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Session declaration and variable initialization. An InteractiveSession doesn't clog the entire GPU."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "sess = tf.compat.v1.InteractiveSession()\n",
+    "sess.run(tf.compat.v1.global_variables_initializer())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# FastCell Training Routine\n",
+    "\n",
+    "The method to run the 3 phase training, followed by giving out the best early-stopping model and accuracy, along with saving of the parameters."
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 0\n", + "\n", + "******************** Dense Training Phase Started ********************\n", + "\n", + "Train Loss: 1.3531070024999854 Train Accuracy: 0.565881378744563\n", + "Test Loss: 0.8334901 Test Accuracy: 0.7349278\n", + "\n", + "Epoch Number: 1\n", + "Train Loss: 0.5264064224615489 Train Accuracy: 0.8227005854044875\n", + "Test Loss: 0.52811986 Test Accuracy: 0.83557546\n", + "\n", + "Epoch Number: 2\n", + "Train Loss: 0.3170111432467421 Train Accuracy: 0.8997546287432109\n", + "Test Loss: 0.41971388 Test Accuracy: 0.87593424\n", + "\n", + "Epoch Number: 3\n", + "Train Loss: 0.22838621382435706 Train Accuracy: 0.9285217539904869\n", + "Test Loss: 0.37176716 Test Accuracy: 0.8943697\n", + "\n", + "Epoch Number: 4\n", + "Train Loss: 0.17584358977332507 Train Accuracy: 0.9436173479850978\n", + "Test Loss: 0.3482268 Test Accuracy: 0.9013453\n", + "\n", + "Epoch Number: 5\n", + "Train Loss: 0.1554100387921072 Train Accuracy: 0.9503703141865665\n", + "Test Loss: 0.36468038 Test Accuracy: 0.8963627\n", + "\n", + "Epoch Number: 6\n", + "Train Loss: 0.13128593576791353 Train Accuracy: 0.9591509887616928\n", + "Test Loss: 0.36238122 Test Accuracy: 0.9028401\n", + "\n", + "Epoch Number: 7\n", + "Train Loss: 0.11856559077150201 Train Accuracy: 0.9623016780369902\n", + "Test Loss: 0.37148365 Test Accuracy: 0.9003488\n", + "\n", + "Epoch Number: 8\n", + "Train Loss: 0.11480801381579 Train Accuracy: 0.9623016764039862\n", + "Test Loss: 0.40140042 Test Accuracy: 0.8938714\n", + "\n", + "Epoch Number: 9\n", + "Train Loss: 0.11065655635440186 Train Accuracy: 0.9653153754260442\n", + "Test Loss: 0.3517686 Test Accuracy: 0.90981567\n", + "\n", + "Epoch Number: 10\n", + "Train Loss: 0.09199796772676788 Train Accuracy: 0.9716302948455288\n", + "Test Loss: 0.3499246 Test Accuracy: 0.9147982\n", + "\n", + "Epoch Number: 11\n", + "Train Loss: 0.07985301451017596 Train Accuracy: 0.9762742788824317\n", + "Test Loss: 0.3625236 Test Accuracy: 0.91529644\n", + "\n", + "Epoch Number: 12\n", + "Train Loss: 0.07171525779397112 Train Accuracy: 0.9787535806224771\n", + "Test Loss: 0.35705435 Test Accuracy: 0.91429996\n", + "\n", + "Epoch Number: 13\n", + "Train Loss: 0.077431221046064 Train Accuracy: 0.9755893504782899\n", + "Test Loss: 0.38592914 Test Accuracy: 0.9093174\n", + "\n", + "Epoch Number: 14\n", + "Train Loss: 0.07726132686007513 Train Accuracy: 0.9744799128950459\n", + "Test Loss: 0.38768652 Test Accuracy: 0.9123069\n", + "\n", + "Epoch Number: 15\n", + "Train Loss: 0.06339540997239416 Train Accuracy: 0.9798494748873253\n", + "Test Loss: 0.36402556 Test Accuracy: 0.9197808\n", + "\n", + "Epoch Number: 16\n", + "Train Loss: 0.0624726844173282 Train Accuracy: 0.9810823528733972\n", + "Test Loss: 0.3556986 Test Accuracy: 0.9192825\n", + "\n", + "Epoch Number: 17\n", + "Train Loss: 0.05848091944082551 Train Accuracy: 0.9821376008530186\n", + "Test Loss: 0.3734596 Test Accuracy: 0.922272\n", + "\n", + "Epoch Number: 18\n", + "Train Loss: 0.06179975296613084 Train Accuracy: 0.9775207050859112\n", + "Test Loss: 0.37375587 Test Accuracy: 0.9147982\n", + "\n", + "Epoch Number: 19\n", + "Train Loss: 0.060816061236474615 Train Accuracy: 0.980534406557475\n", + "Test Loss: 0.36386096 Test Accuracy: 0.92077726\n", + "\n", + "Epoch Number: 20\n", + "Train Loss: 0.05517878877126599 Train Accuracy: 0.9829866126792072\n", + "Test Loss: 
0.38278854 Test Accuracy: 0.92077726\n", + "\n", + "Epoch Number: 21\n", + "Train Loss: 0.04950164187036148 Train Accuracy: 0.9835481072125369\n", + "Test Loss: 0.38189712 Test Accuracy: 0.91878426\n", + "\n", + "Epoch Number: 22\n", + "Train Loss: 0.04603105507893105 Train Accuracy: 0.984219489848777\n", + "Test Loss: 0.39881724 Test Accuracy: 0.9123069\n", + "\n", + "Epoch Number: 23\n", + "Train Loss: 0.04120528124183519 Train Accuracy: 0.985726339359806\n", + "Test Loss: 0.41953668 Test Accuracy: 0.91131043\n", + "\n", + "Epoch Number: 24\n", + "Train Loss: 0.04223672329282312 Train Accuracy: 0.9858497748636219\n", + "Test Loss: 0.37599987 Test Accuracy: 0.9227703\n", + "\n", + "Epoch Number: 25\n", + "Train Loss: 0.044115278812457026 Train Accuracy: 0.9849044190694208\n", + "Test Loss: 0.39963064 Test Accuracy: 0.92127556\n", + "\n", + "Epoch Number: 26\n", + "Train Loss: 0.060125956299064094 Train Accuracy: 0.9792608863686862\n", + "Test Loss: 0.39676014 Test Accuracy: 0.91131043\n", + "\n", + "Epoch Number: 27\n", + "Train Loss: 0.058513890084338514 Train Accuracy: 0.9795484101935609\n", + "Test Loss: 0.3695973 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 28\n", + "Train Loss: 0.04882802803401057 Train Accuracy: 0.9824115707449717\n", + "Test Loss: 0.4062322 Test Accuracy: 0.9128052\n", + "\n", + "Epoch Number: 29\n", + "Train Loss: 0.04246805129853422 Train Accuracy: 0.9854659160522565\n", + "Test Loss: 0.36979795 Test Accuracy: 0.92526156\n", + "\n", + "Epoch Number: 30\n", + "Train Loss: 0.05128337493906283 Train Accuracy: 0.9843700242369142\n", + "Test Loss: 0.4025077 Test Accuracy: 0.9172895\n", + "\n", + "Epoch Number: 31\n", + "Train Loss: 0.04524477290895398 Train Accuracy: 0.9840825028615455\n", + "Test Loss: 0.36316648 Test Accuracy: 0.9227703\n", + "\n", + "Epoch Number: 32\n", + "Train Loss: 0.04791155387966396 Train Accuracy: 0.9839319660239023\n", + "Test Loss: 0.38224837 Test Accuracy: 0.9197808\n", + "\n", + "Epoch Number: 33\n", + "Train Loss: 0.04305804770261253 Train Accuracy: 0.98493151713724\n", + "Test Loss: 0.3597 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 34\n", + "Train Loss: 0.03439056758819888 Train Accuracy: 0.9891509944445467\n", + "Test Loss: 0.36144 Test Accuracy: 0.92326856\n", + "\n", + "Epoch Number: 35\n", + "Train Loss: 0.025825574640057063 Train Accuracy: 0.9935481017583037\n", + "Test Loss: 0.3576532 Test Accuracy: 0.9287494\n", + "\n", + "Epoch Number: 36\n", + "Train Loss: 0.020732127933775726 Train Accuracy: 0.9947809772948696\n", + "Test Loss: 0.3529356 Test Accuracy: 0.92825115\n", + "\n", + "Epoch Number: 37\n", + "Train Loss: 0.02256068464189972 Train Accuracy: 0.9938356215006685\n", + "Test Loss: 0.3675873 Test Accuracy: 0.93223715\n", + "\n", + "Epoch Number: 38\n", + "Train Loss: 0.04096006839025817 Train Accuracy: 0.9857398875772136\n", + "Test Loss: 0.36569017 Test Accuracy: 0.9267564\n", + "\n", + "Epoch Number: 39\n", + "Train Loss: 0.04014190110339694 Train Accuracy: 0.9867123389897281\n", + "Test Loss: 0.34677818 Test Accuracy: 0.9262581\n", + "\n", + "Epoch Number: 40\n", + "Train Loss: 0.031071233378136404 Train Accuracy: 0.9899864605028336\n", + "Test Loss: 0.363686 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 41\n", + "Train Loss: 0.02729316997303538 Train Accuracy: 0.9908219265611204\n", + "Test Loss: 0.35555694 Test Accuracy: 0.9312407\n", + "\n", + "Epoch Number: 42\n", + "Train Loss: 0.021803765849542026 Train Accuracy: 0.992191786635412\n", + "Test Loss: 0.35095477 Test Accuracy: 
0.93223715\n", + "\n", + "Epoch Number: 43\n", + "Train Loss: 0.04842862480460373 Train Accuracy: 0.9833975695583919\n", + "Test Loss: 0.42905322 Test Accuracy: 0.91679126\n", + "\n", + "Epoch Number: 44\n", + "Train Loss: 0.04453416636264692 Train Accuracy: 0.9834111210418074\n", + "Test Loss: 0.406023 Test Accuracy: 0.920279\n", + "\n", + "Epoch Number: 45\n", + "Train Loss: 0.038877726283740914 Train Accuracy: 0.9870962010671015\n", + "Test Loss: 0.39293337 Test Accuracy: 0.91878426\n", + "\n", + "Epoch Number: 46\n", + "Train Loss: 0.034626684416315126 Train Accuracy: 0.9884796118083066\n", + "Test Loss: 0.36277694 Test Accuracy: 0.9237668\n", + "\n", + "Epoch Number: 47\n", + "Train Loss: 0.02302065390889367 Train Accuracy: 0.9934111139545702\n", + "Test Loss: 0.38474992 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 48\n", + "Train Loss: 0.023432086993723292 Train Accuracy: 0.9943564705652733\n", + "Test Loss: 0.370669 Test Accuracy: 0.9237668\n", + "\n", + "Epoch Number: 49\n", + "Train Loss: 0.024380253930097726 Train Accuracy: 0.9921782384180042\n", + "Test Loss: 0.40583202 Test Accuracy: 0.9227703\n", + "\n", + "Epoch Number: 50\n", + "Train Loss: 0.023330659918129854 Train Accuracy: 0.9926027467806046\n", + "Test Loss: 0.4097609 Test Accuracy: 0.92575985\n", + "\n", + "Epoch Number: 51\n", + "Train Loss: 0.018314683679108545 Train Accuracy: 0.9943835661835867\n", + "Test Loss: 0.38972235 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 52\n", + "Train Loss: 0.029633181783600315 Train Accuracy: 0.9905344043692498\n", + "Test Loss: 0.37864792 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 53\n", + "Train Loss: 0.030011002424058235 Train Accuracy: 0.9905479509536534\n", + "Test Loss: 0.3964535 Test Accuracy: 0.9192825\n", + "\n", + "Epoch Number: 54\n", + "Train Loss: 0.03564942483343694 Train Accuracy: 0.9888499256682722\n", + "Test Loss: 0.38546467 Test Accuracy: 0.92326856\n", + "\n", + "Epoch Number: 55\n", + "Train Loss: 0.0320119748230105 Train Accuracy: 0.9893015280161819\n", + "Test Loss: 0.41079342 Test Accuracy: 0.91679126\n", + "\n", + "Epoch Number: 56\n", + "Train Loss: 0.027233783602204225 Train Accuracy: 0.9919042677095492\n", + "Test Loss: 0.40080228 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 57\n", + "Train Loss: 0.0170260006386455 Train Accuracy: 0.9949044152481915\n", + "Test Loss: 0.42503983 Test Accuracy: 0.9292476\n", + "\n", + "Epoch Number: 58\n", + "Train Loss: 0.020110745480513736 Train Accuracy: 0.9946575393415478\n", + "Test Loss: 0.38848647 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 59\n", + "Train Loss: 0.015590530762780611 Train Accuracy: 0.9949179634655991\n", + "Test Loss: 0.4031199 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 60\n", + "Train Loss: 0.022963548624530844 Train Accuracy: 0.992863170088154\n", + "Test Loss: 0.42644864 Test Accuracy: 0.9197808\n", + "\n", + "Epoch Number: 61\n", + "Train Loss: 0.024166807283532536 Train Accuracy: 0.9914933091973606\n", + "Test Loss: 0.4117787 Test Accuracy: 0.9247633\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 62\n", + "Train Loss: 0.01902851595364715 Train Accuracy: 0.9946575393415478\n", + "Test Loss: 0.43569365 Test Accuracy: 0.918286\n", + "\n", + "Epoch Number: 63\n", + "Train Loss: 0.022098659849502402 Train Accuracy: 0.9919178142939529\n", + "Test Loss: 0.4453173 Test Accuracy: 0.92575985\n", + "\n", + "Epoch Number: 64\n", + "Train Loss: 0.02353779313932747 Train Accuracy: 
0.9930001562588835\n", + "Test Loss: 0.43414015 Test Accuracy: 0.91429996\n", + "\n", + "Epoch Number: 65\n", + "Train Loss: 0.016468530626048986 Train Accuracy: 0.9947809764783676\n", + "Test Loss: 0.43052217 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 66\n", + "Train Loss: 0.016379667304086257 Train Accuracy: 0.9958904148781136\n", + "Test Loss: 0.4004999 Test Accuracy: 0.92825115\n", + "\n", + "Epoch Number: 67\n", + "Train Loss: 0.012232361819072026 Train Accuracy: 0.9971232904146795\n", + "Test Loss: 0.40298688 Test Accuracy: 0.93273544\n", + "\n", + "Epoch Number: 68\n", + "Train Loss: 0.008708359920403806 Train Accuracy: 0.998493152121975\n", + "Test Loss: 0.42018083 Test Accuracy: 0.9272546\n", + "\n", + "Epoch Number: 69\n", + "Train Loss: 0.009453040786081134 Train Accuracy: 0.9979452074390568\n", + "Test Loss: 0.42367473 Test Accuracy: 0.9287494\n", + "\n", + "Epoch Number: 70\n", + "Train Loss: 0.02633900548393634 Train Accuracy: 0.9916031989332748\n", + "Test Loss: 0.4314282 Test Accuracy: 0.91778773\n", + "\n", + "Epoch Number: 71\n", + "Train Loss: 0.05996186181277751 Train Accuracy: 0.9832605866536702\n", + "Test Loss: 0.40858173 Test Accuracy: 0.9227703\n", + "\n", + "Epoch Number: 72\n", + "Train Loss: 0.03984937108479032 Train Accuracy: 0.9866987883228145\n", + "Test Loss: 0.41035435 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 73\n", + "Train Loss: 0.024671344705283232 Train Accuracy: 0.991356323026631\n", + "Test Loss: 0.42347214 Test Accuracy: 0.92575985\n", + "\n", + "Epoch Number: 74\n", + "Train Loss: 0.0261542204694108 Train Accuracy: 0.9923287736226435\n", + "Test Loss: 0.40737543 Test Accuracy: 0.92127556\n", + "\n", + "Epoch Number: 75\n", + "Train Loss: 0.021734511994833304 Train Accuracy: 0.9932741294168446\n", + "Test Loss: 0.38865966 Test Accuracy: 0.93024415\n", + "\n", + "Epoch Number: 76\n", + "Train Loss: 0.017603178070028862 Train Accuracy: 0.9942330326119514\n", + "Test Loss: 0.41292053 Test Accuracy: 0.9272546\n", + "\n", + "Epoch Number: 77\n", + "Train Loss: 0.01774944902368987 Train Accuracy: 0.9949179634655991\n", + "Test Loss: 0.3975856 Test Accuracy: 0.9287494\n", + "\n", + "Epoch Number: 78\n", + "Train Loss: 0.026556726039565895 Train Accuracy: 0.9926027451476006\n", + "Test Loss: 0.37759724 Test Accuracy: 0.9262581\n", + "\n", + "Epoch Number: 79\n", + "Train Loss: 0.03763930009971436 Train Accuracy: 0.9888905711369972\n", + "Test Loss: 0.44735578 Test Accuracy: 0.9172895\n", + "\n", + "Epoch Number: 80\n", + "Train Loss: 0.029481777800119496 Train Accuracy: 0.9901369957074727\n", + "Test Loss: 0.41876832 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 81\n", + "Train Loss: 0.02179017917999411 Train Accuracy: 0.9924522115759653\n", + "Test Loss: 0.41835007 Test Accuracy: 0.920279\n", + "\n", + "Epoch Number: 82\n", + "Train Loss: 0.0234184642127007 Train Accuracy: 0.9921646918336006\n", + "Test Loss: 0.41416502 Test Accuracy: 0.9272546\n", + "\n", + "Epoch Number: 83\n", + "Train Loss: 0.02082834580166852 Train Accuracy: 0.993260580382935\n", + "Test Loss: 0.4422068 Test Accuracy: 0.9172895\n", + "\n", + "Epoch Number: 84\n", + "Train Loss: 0.022050149352465156 Train Accuracy: 0.9939726084879\n", + "Test Loss: 0.3987477 Test Accuracy: 0.92825115\n", + "\n", + "Epoch Number: 85\n", + "Train Loss: 0.026048276352549405 Train Accuracy: 0.9927261847339265\n", + "Test Loss: 0.38845396 Test Accuracy: 0.9272546\n", + "\n", + "Epoch Number: 86\n", + "Train Loss: 0.01715031830029409 Train Accuracy: 0.9952054840244658\n", + 
"Test Loss: 0.3792558 Test Accuracy: 0.9267564\n", + "\n", + "Epoch Number: 87\n", + "Train Loss: 0.014544817494636732 Train Accuracy: 0.9964248113436242\n", + "Test Loss: 0.41980278 Test Accuracy: 0.9242651\n", + "\n", + "Epoch Number: 88\n", + "Train Loss: 0.006491333439193462 Train Accuracy: 0.9987671244634341\n", + "Test Loss: 0.39655796 Test Accuracy: 0.9317389\n", + "\n", + "Epoch Number: 89\n", + "Train Loss: 0.004307456604007325 Train Accuracy: 0.9993150691463523\n", + "Test Loss: 0.40233433 Test Accuracy: 0.92825115\n", + "\n", + "Epoch Number: 90\n", + "Train Loss: 0.0027800448436596215 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.403938 Test Accuracy: 0.9287494\n", + "\n", + "Epoch Number: 91\n", + "Train Loss: 0.002242555237002033 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.4070075 Test Accuracy: 0.9307424\n", + "\n", + "Epoch Number: 92\n", + "Train Loss: 0.0022119151863703277 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.41036773 Test Accuracy: 0.9307424\n", + "\n", + "Epoch Number: 93\n", + "Train Loss: 0.001824945211809205 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.41361076 Test Accuracy: 0.9317389\n", + "\n", + "Epoch Number: 94\n", + "Train Loss: 0.001808816738895536 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.41818038 Test Accuracy: 0.93223715\n", + "\n", + "Epoch Number: 95\n", + "Train Loss: 0.0015898832340871482 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.4229942 Test Accuracy: 0.9312407\n", + "\n", + "Epoch Number: 96\n", + "Train Loss: 0.001751650427787067 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.42656386 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 97\n", + "Train Loss: 0.0015788370674023125 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.43016008 Test Accuracy: 0.93223715\n", + "\n", + "Epoch Number: 98\n", + "Train Loss: 0.0016806908206988688 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.43488127 Test Accuracy: 0.9297459\n", + "\n", + "Epoch Number: 99\n", + "Train Loss: 0.0015810940553009996 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.43672302 Test Accuracy: 0.9307424\n", + "\n", + "Epoch Number: 100\n", + "Train Loss: 0.001932646346052037 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.4472658 Test Accuracy: 0.9292476\n", + "\n", + "Epoch Number: 101\n", + "Train Loss: 0.03748205324996914 Train Accuracy: 0.9887535817002597\n", + "Test Loss: 0.44853446 Test Accuracy: 0.9147982\n", + "\n", + "Epoch Number: 102\n", + "Train Loss: 0.11106032654898215 Train Accuracy: 0.9678774721001926\n", + "Test Loss: 0.36246482 Test Accuracy: 0.9227703\n", + "\n", + "Epoch Number: 103\n", + "Train Loss: 0.05241849743569755 Train Accuracy: 0.9829866134957091\n", + "Test Loss: 0.3570766 Test Accuracy: 0.92127556\n", + "\n", + "Epoch Number: 104\n", + "Train Loss: 0.028517094278095723 Train Accuracy: 0.9917672799058157\n", + "Test Loss: 0.37065578 Test Accuracy: 0.9292476\n", + "\n", + "Epoch Number: 105\n", + "Train Loss: 0.017652338973488915 Train Accuracy: 0.9942330326119514\n", + "Test Loss: 0.37107578 Test Accuracy: 0.93223715\n", + "\n", + "Epoch Number: 106\n", + "Train Loss: 0.01568616884250245 Train Accuracy: 0.9952054840244658\n", + "Test Loss: 0.35663217 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 107\n", + "Train Loss: 0.017838217169748084 Train Accuracy: 0.9951919358070582\n", + "Test Loss: 0.41287652 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 108\n", + "Train Loss: 0.010033470020734717 Train Accuracy: 0.9976712350975977\n", + 
"Test Loss: 0.4009026 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 109\n", + "Train Loss: 0.008324489026961925 Train Accuracy: 0.9980686453923787\n", + "Test Loss: 0.4033402 Test Accuracy: 0.9297459\n", + "\n", + "Epoch Number: 110\n", + "Train Loss: 0.00771069963668371 Train Accuracy: 0.9986165900752969\n", + "Test Loss: 0.41503954 Test Accuracy: 0.93223715\n", + "\n", + "Epoch Number: 111\n", + "Train Loss: 0.017172634716413608 Train Accuracy: 0.9951919358070582\n", + "Test Loss: 0.43192545 Test Accuracy: 0.9237668\n", + "\n", + "Epoch Number: 112\n", + "Train Loss: 0.03227482749984842 Train Accuracy: 0.9877946801381569\n", + "Test Loss: 0.44748837 Test Accuracy: 0.91579473\n", + "\n", + "Epoch Number: 113\n", + "Train Loss: 0.03210181032135215 Train Accuracy: 0.9895890493915506\n", + "Test Loss: 0.4282059 Test Accuracy: 0.92127556\n", + "\n", + "Epoch Number: 114\n", + "Train Loss: 0.01357129268182365 Train Accuracy: 0.9964383595610318\n", + "Test Loss: 0.394982 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 115\n", + "Train Loss: 0.019958787539508194 Train Accuracy: 0.9939455112365827\n", + "Test Loss: 0.44919127 Test Accuracy: 0.92326856\n", + "\n", + "Epoch Number: 116\n", + "Train Loss: 0.01951833989204877 Train Accuracy: 0.9941095938421276\n", + "Test Loss: 0.39846456 Test Accuracy: 0.9262581\n", + "\n", + "Epoch Number: 117\n", + "Train Loss: 0.013109919135873397 Train Accuracy: 0.9961643872195727\n", + "Test Loss: 0.3964593 Test Accuracy: 0.93273544\n", + "\n", + "Epoch Number: 118\n", + "Train Loss: 0.008196171877063708 Train Accuracy: 0.9978082212683272\n", + "Test Loss: 0.39881837 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 119\n", + "Train Loss: 0.0053880620705544285 Train Accuracy: 0.9991780829756227\n", + "Test Loss: 0.3994898 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 120\n", + "Train Loss: 0.004016872985792436 Train Accuracy: 0.9990410968048932\n", + "Test Loss: 0.40483266 Test Accuracy: 0.93821627\n", + "\n", + "Epoch Number: 121\n", + "Train Loss: 0.003189845641752807 Train Accuracy: 0.9994520553170818\n", + "Test Loss: 0.4129637 Test Accuracy: 0.9317389\n", + "\n", + "Epoch Number: 122\n", + "Train Loss: 0.0019790745577105157 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.41090903 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 123\n", + "Train Loss: 0.001803544777754873 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41632676 Test Accuracy: 0.9362232\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 124\n", + "Train Loss: 0.0015932139348516189 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.4208521 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 125\n", + "Train Loss: 0.0016064488660697252 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.4273708 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 126\n", + "Train Loss: 0.0015048502509048438 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.43023735 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 127\n", + "Train Loss: 0.0014419755101050824 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.4389877 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 128\n", + "Train Loss: 0.0013684726919028398 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.44143116 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 129\n", + "Train Loss: 0.0013124902181690943 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.44684827 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 
130\n", + "Train Loss: 0.001271455863264249 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.44386175 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 131\n", + "Train Loss: 0.0013829727382247642 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.45779392 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 132\n", + "Train Loss: 0.0019963769629288285 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.45586306 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 133\n", + "Train Loss: 0.0016972724874966382 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.4626169 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 134\n", + "Train Loss: 0.08464051221679954 Train Accuracy: 0.9767951271305345\n", + "Test Loss: 0.5116203 Test Accuracy: 0.89287496\n", + "\n", + "Epoch Number: 135\n", + "Train Loss: 0.06996250499282287 Train Accuracy: 0.9749044163586342\n", + "Test Loss: 0.40689448 Test Accuracy: 0.9242651\n", + "\n", + "Epoch Number: 136\n", + "Train Loss: 0.03188256850970067 Train Accuracy: 0.9891509952610487\n", + "Test Loss: 0.38704544 Test Accuracy: 0.9262581\n", + "\n", + "Epoch Number: 137\n", + "Train Loss: 0.024611471771032945 Train Accuracy: 0.9905479534031594\n", + "Test Loss: 0.37838554 Test Accuracy: 0.92526156\n", + "\n", + "Epoch Number: 138\n", + "Train Loss: 0.010400510262315199 Train Accuracy: 0.9975342489268682\n", + "Test Loss: 0.37540606 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 139\n", + "Train Loss: 0.01360704386503155 Train Accuracy: 0.9961643864030707\n", + "Test Loss: 0.3835624 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 140\n", + "Train Loss: 0.011857587645589437 Train Accuracy: 0.9971097421972719\n", + "Test Loss: 0.38600543 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 141\n", + "Train Loss: 0.009779149474263548 Train Accuracy: 0.9979452074390568\n", + "Test Loss: 0.3920699 Test Accuracy: 0.9292476\n", + "\n", + "Epoch Number: 142\n", + "Train Loss: 0.011075850638356826 Train Accuracy: 0.9968493180732204\n", + "Test Loss: 0.39335945 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 143\n", + "Train Loss: 0.007224232131018852 Train Accuracy: 0.9982191797805159\n", + "Test Loss: 0.39289063 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 144\n", + "Train Loss: 0.00687252213119542 Train Accuracy: 0.9986165900752969\n", + "Test Loss: 0.41478387 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 145\n", + "Train Loss: 0.0036116211602547246 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.38912663 Test Accuracy: 0.9412058\n", + "\n", + "Epoch Number: 146\n", + "Train Loss: 0.002582093849076494 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38982612 Test Accuracy: 0.93921274\n", + "\n", + "Epoch Number: 147\n", + "Train Loss: 0.0020956868141943793 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3942846 Test Accuracy: 0.9407075\n", + "\n", + "Epoch Number: 148\n", + "Train Loss: 0.0018172568598943954 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.40042797 Test Accuracy: 0.93971103\n", + "\n", + "Epoch Number: 149\n", + "Train Loss: 0.004224563434110852 Train Accuracy: 0.9993150691463523\n", + "Test Loss: 0.41389707 Test Accuracy: 0.9387145\n", + "\n", + "Epoch Number: 150\n", + "Train Loss: 0.0033937980600693327 Train Accuracy: 0.9991780829756227\n", + "Test Loss: 0.4309551 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 151\n", + "Train Loss: 0.016683471513902513 Train Accuracy: 0.995753428707384\n", + "Test Loss: 0.5002462 Test Accuracy: 0.92326856\n", + "\n", + "Epoch 
Number: 152\n", + "Train Loss: 0.07921012418587016 Train Accuracy: 0.9773566224803664\n", + "Test Loss: 0.49312237 Test Accuracy: 0.91380167\n", + "\n", + "Epoch Number: 153\n", + "Train Loss: 0.06044871790321824 Train Accuracy: 0.9809589198190872\n", + "Test Loss: 0.39408478 Test Accuracy: 0.918286\n", + "\n", + "Epoch Number: 154\n", + "Train Loss: 0.03367851439015047 Train Accuracy: 0.9889041193544048\n", + "Test Loss: 0.37989596 Test Accuracy: 0.9307424\n", + "\n", + "Epoch Number: 155\n", + "Train Loss: 0.017698209322925196 Train Accuracy: 0.9939319638356771\n", + "Test Loss: 0.37573016 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 156\n", + "Train Loss: 0.010081476129253383 Train Accuracy: 0.9975342489268682\n", + "Test Loss: 0.38967404 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 157\n", + "Train Loss: 0.00447188632057227 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.38149583 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 158\n", + "Train Loss: 0.0025207580036120105 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38982794 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 159\n", + "Train Loss: 0.0020615914665092394 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38763282 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 160\n", + "Train Loss: 0.0017625982677787286 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.39443132 Test Accuracy: 0.93821627\n", + "\n", + "Epoch Number: 161\n", + "Train Loss: 0.0015912198259061432 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.39672822 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 162\n", + "Train Loss: 0.0015430196065994693 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.40416223 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 163\n", + "Train Loss: 0.0014482194389143245 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.40405902 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 164\n", + "Train Loss: 0.001342231735877361 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41373003 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 165\n", + "Train Loss: 0.001386067051797697 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41112974 Test Accuracy: 0.9387145\n", + "\n", + "Epoch Number: 166\n", + "Train Loss: 0.001193268498392259 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.42294186 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 167\n", + "Train Loss: 0.0012895042981317727 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41742194 Test Accuracy: 0.9387145\n", + "\n", + "Epoch Number: 168\n", + "Train Loss: 0.001108311818077784 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.4331248 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 169\n", + "Train Loss: 0.001382393570304274 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.42247245 Test Accuracy: 0.93971103\n", + "\n", + "Epoch Number: 170\n", + "Train Loss: 0.0012338795040954184 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.45070615 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 171\n", + "Train Loss: 0.002104504165289069 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.4341877 Test Accuracy: 0.93921274\n", + "\n", + "Epoch Number: 172\n", + "Train Loss: 0.004313505504460534 Train Accuracy: 0.9984796039045674\n", + "Test Loss: 0.46030065 Test Accuracy: 0.9297459\n", + "\n", + "Epoch Number: 173\n", + "Train Loss: 0.093728030188175 Train Accuracy: 0.9727126351774555\n", + "Test Loss: 0.42030093 Test Accuracy: 0.91629297\n", + 
"\n", + "Epoch Number: 174\n", + "Train Loss: 0.05609176247635831 Train Accuracy: 0.9824522162136966\n", + "Test Loss: 0.41325808 Test Accuracy: 0.92127556\n", + "\n", + "Epoch Number: 175\n", + "Train Loss: 0.02312381442338037 Train Accuracy: 0.9921782392345063\n", + "Test Loss: 0.4045683 Test Accuracy: 0.9242651\n", + "\n", + "Epoch Number: 176\n", + "Train Loss: 0.014690037730647481 Train Accuracy: 0.9953424693786934\n", + "Test Loss: 0.3937127 Test Accuracy: 0.9307424\n", + "\n", + "Epoch Number: 177\n", + "Train Loss: 0.00893136704809428 Train Accuracy: 0.9982191797805159\n", + "Test Loss: 0.39721966 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 178\n", + "Train Loss: 0.007236258619168314 Train Accuracy: 0.9989041106341636\n", + "Test Loss: 0.38853252 Test Accuracy: 0.92775285\n", + "\n", + "Epoch Number: 179\n", + "Train Loss: 0.004083093933518721 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.37919173 Test Accuracy: 0.9312407\n", + "\n", + "Epoch Number: 180\n", + "Train Loss: 0.002494418898705801 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.39079925 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 181\n", + "Train Loss: 0.001869901260624206 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3961389 Test Accuracy: 0.9312407\n", + "\n", + "Epoch Number: 182\n", + "Train Loss: 0.0017469667874225607 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.4000474 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 183\n", + "Train Loss: 0.0012739899573043908 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.40910247 Test Accuracy: 0.93273544\n", + "\n", + "Epoch Number: 184\n", + "Train Loss: 0.0013601894353672684 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41040978 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 185\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Loss: 0.0011997123000495875 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41469583 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 186\n", + "Train Loss: 0.0014707065065397741 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.41498712 Test Accuracy: 0.94020927\n", + "\n", + "Epoch Number: 187\n", + "Train Loss: 0.002836034407166203 Train Accuracy: 0.9993150691463523\n", + "Test Loss: 0.45857498 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 188\n", + "Train Loss: 0.04335543119080671 Train Accuracy: 0.9872467313727288\n", + "Test Loss: 0.4854505 Test Accuracy: 0.91081214\n", + "\n", + "Epoch Number: 189\n", + "Train Loss: 0.0711721541576904 Train Accuracy: 0.9769863102534045\n", + "Test Loss: 0.38205332 Test Accuracy: 0.92575985\n", + "\n", + "Epoch Number: 190\n", + "Train Loss: 0.028694037625515093 Train Accuracy: 0.9897260347457781\n", + "Test Loss: 0.3982458 Test Accuracy: 0.9217738\n", + "\n", + "Epoch Number: 191\n", + "Train Loss: 0.020338214381298125 Train Accuracy: 0.9945205531708182\n", + "Test Loss: 0.39259014 Test Accuracy: 0.9317389\n", + "\n", + "Epoch Number: 192\n", + "Train Loss: 0.015414321870058265 Train Accuracy: 0.995753428707384\n", + "Test Loss: 0.44486946 Test Accuracy: 0.9247633\n", + "\n", + "Epoch Number: 193\n", + "Train Loss: 0.018497141398655326 Train Accuracy: 0.9928767191220637\n", + "Test Loss: 0.39597595 Test Accuracy: 0.9242651\n", + "\n", + "Epoch Number: 194\n", + "Train Loss: 0.015433995104203485 Train Accuracy: 0.9949315116830069\n", + "Test Loss: 0.38098228 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 195\n", + "Train Loss: 0.006958263319712101 Train Accuracy: 
0.9980821936097863\n", + "Test Loss: 0.3999225 Test Accuracy: 0.93273544\n", + "\n", + "Epoch Number: 196\n", + "Train Loss: 0.005096633062213149 Train Accuracy: 0.9991780829756227\n", + "Test Loss: 0.37073624 Test Accuracy: 0.937718\n", + "\n", + "Epoch Number: 197\n", + "Train Loss: 0.00343738832432866 Train Accuracy: 0.9995890414878114\n", + "Test Loss: 0.37505808 Test Accuracy: 0.9332337\n", + "\n", + "Epoch Number: 198\n", + "Train Loss: 0.0018556024091818996 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.37871873 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 199\n", + "Train Loss: 0.0015186609035319559 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3814098 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 200\n", + "Train Loss: 0.0010581495521800619 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38004857 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 201\n", + "Train Loss: 0.0009942382828440925 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.37956578 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 202\n", + "Train Loss: 0.0009491439842561592 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.37938344 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 203\n", + "Train Loss: 0.0009144399371333038 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.37940732 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 204\n", + "Train Loss: 0.0008879157705376027 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.37957668 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 205\n", + "Train Loss: 0.0008668735051648819 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3798471 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 206\n", + "Train Loss: 0.0008491342837447046 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38018626 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 207\n", + "Train Loss: 0.0008333472782994735 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38057283 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 208\n", + "Train Loss: 0.0008187590400953076 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38099325 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 209\n", + "Train Loss: 0.0008049521249183135 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3814406 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 210\n", + "Train Loss: 0.0007916802561551664 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38191074 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 211\n", + "Train Loss: 0.0007788011856450482 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.38240153 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 212\n", + "Train Loss: 0.0007662154095843817 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.3829116 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 213\n", + "Train Loss: 0.0007538625193849104 Train Accuracy: 1.0\n", + "Test Loss: 0.38343957 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 214\n", + "Train Loss: 0.0007416940372347934 Train Accuracy: 1.0\n", + "Test Loss: 0.38398492 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 215\n", + "Train Loss: 0.0007296792061661357 Train Accuracy: 1.0\n", + "Test Loss: 0.3845468 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 216\n", + "Train Loss: 0.0007177960753125101 Train Accuracy: 1.0\n", + "Test Loss: 0.38512433 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 217\n", + "Train Loss: 0.0007060262115834893 Train Accuracy: 1.0\n", + "Test Loss: 0.3857176 Test 
Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 218\n", + "Train Loss: 0.0006943566685587117 Train Accuracy: 1.0\n", + "Test Loss: 0.38632545 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 219\n", + "Train Loss: 0.0006827760203932859 Train Accuracy: 1.0\n", + "Test Loss: 0.38694763 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 220\n", + "Train Loss: 0.0006712808288483918 Train Accuracy: 1.0\n", + "Test Loss: 0.38758373 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 221\n", + "Train Loss: 0.0006598622694831622 Train Accuracy: 1.0\n", + "Test Loss: 0.38823336 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 222\n", + "Train Loss: 0.0006485197959030812 Train Accuracy: 1.0\n", + "Test Loss: 0.3888961 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 223\n", + "Train Loss: 0.0006372484367353561 Train Accuracy: 1.0\n", + "Test Loss: 0.38957155 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 224\n", + "Train Loss: 0.0006260495690297183 Train Accuracy: 1.0\n", + "Test Loss: 0.3902598 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 225\n", + "Train Loss: 0.0006149213456896402 Train Accuracy: 1.0\n", + "Test Loss: 0.39096028 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 226\n", + "Train Loss: 0.0006038654383913014 Train Accuracy: 1.0\n", + "Test Loss: 0.39167294 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 227\n", + "Train Loss: 0.0005928805896897532 Train Accuracy: 1.0\n", + "Test Loss: 0.3923978 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 228\n", + "Train Loss: 0.0005819705751093028 Train Accuracy: 1.0\n", + "Test Loss: 0.3931346 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 229\n", + "Train Loss: 0.0005711371570264232 Train Accuracy: 1.0\n", + "Test Loss: 0.39388332 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 230\n", + "Train Loss: 0.0005603803631495556 Train Accuracy: 1.0\n", + "Test Loss: 0.39464444 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 231\n", + "Train Loss: 0.0005497033376890962 Train Accuracy: 1.0\n", + "Test Loss: 0.3954174 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 232\n", + "Train Loss: 0.0005391083086828051 Train Accuracy: 1.0\n", + "Test Loss: 0.3962025 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 233\n", + "Train Loss: 0.0005285999931439706 Train Accuracy: 1.0\n", + "Test Loss: 0.39699972 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 234\n", + "Train Loss: 0.0005181774882558249 Train Accuracy: 1.0\n", + "Test Loss: 0.39780933 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 235\n", + "Train Loss: 0.0005078478582755165 Train Accuracy: 1.0\n", + "Test Loss: 0.39863133 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 236\n", + "Train Loss: 0.000497609365046541 Train Accuracy: 1.0\n", + "Test Loss: 0.39946586 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 237\n", + "Train Loss: 0.0004874655870443261 Train Accuracy: 1.0\n", + "Test Loss: 0.40031332 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 238\n", + "Train Loss: 0.0004774212549370395 Train Accuracy: 1.0\n", + "Test Loss: 0.40117365 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 239\n", + "Train Loss: 0.0004674753974341749 Train Accuracy: 1.0\n", + "Test Loss: 0.40204704 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 240\n", + "Train Loss: 0.0004576308943198879 Train Accuracy: 1.0\n", + "Test Loss: 0.40293333 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 241\n", + "Train Loss: 0.0004478914650039084 Train Accuracy: 1.0\n", + "Test Loss: 0.40383312 Test Accuracy: 
0.93472844\n", + "\n", + "Epoch Number: 242\n", + "Train Loss: 0.00043825685871487534 Train Accuracy: 1.0\n", + "Test Loss: 0.40474612 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 243\n", + "Train Loss: 0.0004287295363211928 Train Accuracy: 1.0\n", + "Test Loss: 0.40567285 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 244\n", + "Train Loss: 0.00041931199118389734 Train Accuracy: 1.0\n", + "Test Loss: 0.4066128 Test Accuracy: 0.9337319\n", + "\n", + "Epoch Number: 245\n", + "Train Loss: 0.0004100034349081298 Train Accuracy: 1.0\n", + "Test Loss: 0.40756655 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 246\n", + "Train Loss: 0.000400808029740328 Train Accuracy: 1.0\n", + "Test Loss: 0.408534 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 247\n", + "Train Loss: 0.00039172325889685205 Train Accuracy: 1.0\n", + "Test Loss: 0.40951535 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 248\n", + "Train Loss: 0.00038275338309874424 Train Accuracy: 1.0\n", + "Test Loss: 0.41051057 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 249\n", + "Train Loss: 0.00037389971720124915 Train Accuracy: 1.0\n", + "Test Loss: 0.4115201 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 250\n", + "Train Loss: 0.0003651603550031424 Train Accuracy: 1.0\n", + "Test Loss: 0.4125443 Test Accuracy: 0.9352267\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Epoch Number: 251\n", + "Train Loss: 0.0003565376773372943 Train Accuracy: 1.0\n", + "Test Loss: 0.4135833 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 252\n", + "Train Loss: 0.0003480334934253845 Train Accuracy: 1.0\n", + "Test Loss: 0.41463742 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 253\n", + "Train Loss: 0.0003396494167359316 Train Accuracy: 1.0\n", + "Test Loss: 0.41570726 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 254\n", + "Train Loss: 0.0003313843925540935 Train Accuracy: 1.0\n", + "Test Loss: 0.41679344 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 255\n", + "Train Loss: 0.00032324064602718165 Train Accuracy: 1.0\n", + "Test Loss: 0.4178963 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 256\n", + "Train Loss: 0.00031521800582134485 Train Accuracy: 1.0\n", + "Test Loss: 0.4190162 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 257\n", + "Train Loss: 0.0003073199977859064 Train Accuracy: 1.0\n", + "Test Loss: 0.42015436 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 258\n", + "Train Loss: 0.00029954845147535896 Train Accuracy: 1.0\n", + "Test Loss: 0.42131123 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 259\n", + "Train Loss: 0.00029190532083316924 Train Accuracy: 1.0\n", + "Test Loss: 0.42248726 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 260\n", + "Train Loss: 0.0002843899977686879 Train Accuracy: 1.0\n", + "Test Loss: 0.42368302 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 261\n", + "Train Loss: 0.0002770060498932415 Train Accuracy: 1.0\n", + "Test Loss: 0.42489874 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 262\n", + "Train Loss: 0.00026975655114941605 Train Accuracy: 1.0\n", + "Test Loss: 0.4261356 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 263\n", + "Train Loss: 0.0002626420347752011 Train Accuracy: 1.0\n", + "Test Loss: 0.4273931 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 264\n", + "Train Loss: 0.0002556652377267075 Train Accuracy: 1.0\n", + "Test Loss: 0.4286726 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 265\n", + "Train Loss: 
0.0002488270239491488 Train Accuracy: 1.0\n", + "Test Loss: 0.42997336 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 266\n", + "Train Loss: 0.00024212878793462064 Train Accuracy: 1.0\n", + "Test Loss: 0.4312961 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 267\n", + "Train Loss: 0.00023557050247303505 Train Accuracy: 1.0\n", + "Test Loss: 0.43264046 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 268\n", + "Train Loss: 0.000229153466771344 Train Accuracy: 1.0\n", + "Test Loss: 0.43400633 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 269\n", + "Train Loss: 0.00022287752135650396 Train Accuracy: 1.0\n", + "Test Loss: 0.43539396 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 270\n", + "Train Loss: 0.00021674406400974203 Train Accuracy: 1.0\n", + "Test Loss: 0.43680328 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 271\n", + "Train Loss: 0.0002107478593307075 Train Accuracy: 1.0\n", + "Test Loss: 0.43823338 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 272\n", + "Train Loss: 0.00020489141274058605 Train Accuracy: 1.0\n", + "Test Loss: 0.43968424 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 273\n", + "Train Loss: 0.00019917354314214844 Train Accuracy: 1.0\n", + "Test Loss: 0.4411549 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 274\n", + "Train Loss: 0.00019359330177045594 Train Accuracy: 1.0\n", + "Test Loss: 0.44264516 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 275\n", + "Train Loss: 0.00018814832334126 Train Accuracy: 1.0\n", + "Test Loss: 0.4441542 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 276\n", + "Train Loss: 0.00018283885997147554 Train Accuracy: 1.0\n", + "Test Loss: 0.4456814 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 277\n", + "Train Loss: 0.00017766496358951234 Train Accuracy: 1.0\n", + "Test Loss: 0.4472254 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 278\n", + "Train Loss: 0.00017262536239784771 Train Accuracy: 1.0\n", + "Test Loss: 0.44878626 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 279\n", + "Train Loss: 0.00016772304865896818 Train Accuracy: 1.0\n", + "Test Loss: 0.4503633 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 280\n", + "Train Loss: 0.00016296197600863645 Train Accuracy: 1.0\n", + "Test Loss: 0.451956 Test Accuracy: 0.93721974\n", + "\n", + "Epoch Number: 281\n", + "Train Loss: 0.00015834992364158995 Train Accuracy: 1.0\n", + "Test Loss: 0.45356566 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 282\n", + "Train Loss: 0.00015390685763272253 Train Accuracy: 1.0\n", + "Test Loss: 0.45519423 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 283\n", + "Train Loss: 0.0001496767374241967 Train Accuracy: 1.0\n", + "Test Loss: 0.4568483 Test Accuracy: 0.9367215\n", + "\n", + "Epoch Number: 284\n", + "Train Loss: 0.00014578019630017192 Train Accuracy: 1.0\n", + "Test Loss: 0.45854405 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 285\n", + "Train Loss: 0.00014262170305001957 Train Accuracy: 1.0\n", + "Test Loss: 0.46033803 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 286\n", + "Train Loss: 0.00014241321601630636 Train Accuracy: 1.0\n", + "Test Loss: 0.46250543 Test Accuracy: 0.9362232\n", + "\n", + "Epoch Number: 287\n", + "Train Loss: 0.00039112994681377196 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.47080484 Test Accuracy: 0.937718\n", + "\n", + "Epoch Number: 288\n", + "Train Loss: 0.002160546216233243 Train Accuracy: 0.9993150691463523\n", + "Test Loss: 0.4985566 Test Accuracy: 0.9332337\n", + "\n", + 
"Epoch Number: 289\n", + "Train Loss: 0.0015827084215531725 Train Accuracy: 0.9997260276585409\n", + "Test Loss: 0.460808 Test Accuracy: 0.9342302\n", + "\n", + "Epoch Number: 290\n", + "Train Loss: 0.0013418768471824914 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.4653932 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 291\n", + "Train Loss: 0.00042524092328078463 Train Accuracy: 0.9998630138292705\n", + "Test Loss: 0.46051893 Test Accuracy: 0.93472844\n", + "\n", + "Epoch Number: 292\n", + "Train Loss: 0.0002157161274326053 Train Accuracy: 1.0\n", + "Test Loss: 0.45942155 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 293\n", + "Train Loss: 0.00018626744945135688 Train Accuracy: 1.0\n", + "Test Loss: 0.45938873 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 294\n", + "Train Loss: 0.000173948956017577 Train Accuracy: 1.0\n", + "Test Loss: 0.4597626 Test Accuracy: 0.935725\n", + "\n", + "Epoch Number: 295\n", + "Train Loss: 0.00016522929556779436 Train Accuracy: 1.0\n", + "Test Loss: 0.4602842 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 296\n", + "Train Loss: 0.00015823588222552295 Train Accuracy: 1.0\n", + "Test Loss: 0.46087208 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 297\n", + "Train Loss: 0.00015232920922655517 Train Accuracy: 1.0\n", + "Test Loss: 0.4614972 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 298\n", + "Train Loss: 0.00014718409047792156 Train Accuracy: 1.0\n", + "Test Loss: 0.46214733 Test Accuracy: 0.9352267\n", + "\n", + "Epoch Number: 299\n", + "Train Loss: 0.00014260701286667889 Train Accuracy: 1.0\n", + "Test Loss: 0.46281824 Test Accuracy: 0.9352267\n", + "\n", + "Maximum Test accuracy at compressed model size(including early stopping): 0.9412058 at Epoch: 146\n", + "Final Test Accuracy: 0.9352267\n", + "\n", + "\n", + "Non-Zeros: 1932 Model Size: 7.546875 KB hasSparse: False\n", + "\n", + "The Model Directory: usps10\\FastGRNNResults/23_51_17_15_03_19\n", + "\n" + ] + } + ], + "source": [ + "FastCellTrainer.train(batchSize, totalEpochs, sess, Xtrain, Xtest,\n", + " Ytrain, Ytest, decayStep, decayRate, dataDir, currDir)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Quantization\n", + "\n", + "Byte Quantization for the trained FastModels, to reduce the model size by 4x. 
If one uses piece-wise linear approximations for the non-linearities (quantTanh for tanh and quantSigm for sigmoid), the quantized model can benefit greatly from pure integer arithmetic during prediction." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Bg.npy has max: 4.9833384 min: -0.6077357\n", + "Bh.npy has max: 2.8973198 min: -0.16004847\n", + "FC.npy has max: 4.9540076 min: -5.963999\n", + "FCbias.npy has max: 2.540496 min: -1.7358814\n", + "U.npy has max: 2.2965062 min: -2.670992\n", + "W.npy has max: 1.3919494 min: -1.2454427\n", + "\n", + "\n", + "Quantized Model Dir: usps10\\FastGRNNResults/23_51_17_15_03_19\\QuantizedFastModel\n" + ] + } + ], + "source": [ + "# Model quantization\n", + "model_dir = currDir # the model directory is printed at the end of training; use that here, or use currDir\n", + "\n", + "import quantizeFastModels\n", + "quantizeFastModels.quantizeFastModels(model_dir)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tf2.0/examples/FastCells/fastcell_example.py b/tf2.0/examples/FastCells/fastcell_example.py new file mode 100644 index 000000000..1d5468101 --- /dev/null +++ b/tf2.0/examples/FastCells/fastcell_example.py @@ -0,0 +1,99 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
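[Editor's note] To make the byte-quantization step described in the notebook concrete, here is a minimal standalone sketch of the symmetric scale-round-clip scheme that `quantizeFastModels` applies to each weight matrix. `quantize_matrix` is an illustrative name, not part of the `edgeml` API:

```
import numpy as np

def quantize_matrix(A, max_value=127):
    # Illustrative sketch, not the library API: pick a scale so the
    # largest-magnitude entry maps near max_value, then round and clip
    # into a signed byte.
    limit = np.max(np.abs(A))
    scale = np.round((2.0 * max_value + 1.0) / (2.0 * limit))
    Q = np.clip(np.round(scale * A), -(max_value + 1), max_value)
    return Q.astype('int8'), scale  # prediction code divides by `scale`

W = np.array([[1.39, -1.25], [0.02, -0.61]])
qW, s = quantize_matrix(W)  # int8 storage is 4x smaller than float32
```

With quantTanh/quantSigm substituted for tanh/sigmoid, the rounded weights keep the whole forward pass in integer arithmetic.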
+ +import helpermethods +import tensorflow as tf +import numpy as np +import sys + +from edgeml.trainer.fastTrainer import FastTrainer +from edgeml.graph.rnn import FastGRNNCell +from edgeml.graph.rnn import FastRNNCell +from edgeml.graph.rnn import UGRNNLRCell +from edgeml.graph.rnn import GRULRCell +from edgeml.graph.rnn import LSTMLRCell + +tf.compat.v1.disable_eager_execution() + +def main(): + # Fixing seeds for reproducibility + tf.compat.v1.set_random_seed(42) + np.random.seed(42) + + # Hyperparameter pre-processing + args = helpermethods.getArgs() + + dataDir = args.data_dir + cell = args.cell + inputDims = args.input_dim + hiddenDims = args.hidden_dim + + totalEpochs = args.epochs + learningRate = args.learning_rate + outFile = args.output_file + batchSize = args.batch_size + decayStep = args.decay_step + decayRate = args.decay_rate + + wRank = args.wRank + uRank = args.uRank + + sW = args.sW + sU = args.sU + + update_non_linearity = args.update_nl + gate_non_linearity = args.gate_nl + + (dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, + mean, std) = helpermethods.preProcessData(dataDir) + + assert dataDimension % inputDims == 0, "Infeasible per-step input; " + \ + "the number of timesteps must be an integer" + + X = tf.compat.v1.placeholder( + "float", [None, int(dataDimension / inputDims), inputDims]) + Y = tf.compat.v1.placeholder("float", [None, numClasses]) + + currDir = helpermethods.createTimeStampDir(dataDir, cell) + + helpermethods.dumpCommand(sys.argv, currDir) + helpermethods.saveMeanStd(mean, std, currDir) + + if cell == "FastGRNN": + FastCell = FastGRNNCell(hiddenDims, + gate_non_linearity=gate_non_linearity, + update_non_linearity=update_non_linearity, + wRank=wRank, uRank=uRank) + elif cell == "FastRNN": + FastCell = FastRNNCell(hiddenDims, + update_non_linearity=update_non_linearity, + wRank=wRank, uRank=uRank) + elif cell == "UGRNN": + FastCell = UGRNNLRCell(hiddenDims, + update_non_linearity=update_non_linearity, + wRank=wRank, uRank=uRank) + elif cell == "GRU": + FastCell = GRULRCell(hiddenDims, + update_non_linearity=update_non_linearity, + wRank=wRank, uRank=uRank) + elif cell == "LSTM": + FastCell = LSTMLRCell(hiddenDims, + update_non_linearity=update_non_linearity, + wRank=wRank, uRank=uRank) + else: + sys.exit('Exiting: no such cell as ' + cell) + + FastCellTrainer = FastTrainer( + FastCell, X, Y, sW=sW, sU=sU, + learningRate=learningRate, outFile=outFile) + + sess = tf.compat.v1.InteractiveSession() + sess.run(tf.compat.v1.global_variables_initializer()) + + FastCellTrainer.train(batchSize, totalEpochs, sess, Xtrain, Xtest, + Ytrain, Ytest, decayStep, decayRate, + dataDir, currDir) + + +if __name__ == '__main__': + main() diff --git a/tf2.0/examples/FastCells/fetch_usps.py b/tf2.0/examples/FastCells/fetch_usps.py new file mode 100644 index 000000000..a5c314369 --- /dev/null +++ b/tf2.0/examples/FastCells/fetch_usps.py @@ -0,0 +1,66 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Setting up the USPS Data. + +import bz2 +import os +import subprocess +import sys + +import requests +import numpy as np +from sklearn.datasets import load_svmlight_file +from helpermethods import download_file, decompress + + + +def downloadData(workingDir, downloadDir, linkTrain, linkTest): + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + try: + os.makedirs(path, exist_ok=True) + except OSError: + print("Could not create %s. 
Make sure the path does" % path) + print("not already exist and you have permissions to create it.") + return False + + training_data_bz2 = download_file(linkTrain, path) + test_data_bz2 = download_file(linkTest, path) + + training_data = decompress(training_data_bz2) + test_data = decompress(test_data_bz2) + + train = os.path.join(path, "train.txt") + test = os.path.join(path, "test.txt") + if os.path.isfile(train): + os.remove(train) + if os.path.isfile(test): + os.remove(test) + + os.rename(training_data, train) + os.rename(test_data, test) + os.remove(training_data_bz2) + os.remove(test_data_bz2) + return True + +if __name__ == '__main__': + workingDir = './' + downloadDir = 'usps10' + linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2' + linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2' + failureMsg = ''' +Download Failed! +To manually perform the download +\t1. Create a new empty directory named `usps10`. +\t2. Download the data from the following links into the usps10 directory. +\t\tTest: %s +\t\tTrain: %s +\t3. Extract the downloaded files. +\t4. Rename `usps` to `train.txt` and, +\t5. Rename `usps.t` to `test.txt +''' % (linkTrain, linkTest) + + if not downloadData(workingDir, downloadDir, linkTrain, linkTest): + exit(failureMsg) + print("Done: see ", downloadDir) diff --git a/tf2.0/examples/FastCells/helpermethods.py b/tf2.0/examples/FastCells/helpermethods.py new file mode 100644 index 000000000..a052330f3 --- /dev/null +++ b/tf2.0/examples/FastCells/helpermethods.py @@ -0,0 +1,273 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +''' + Functions to check sanity of input arguments + for the example script. +''' +import argparse +import bz2 +import datetime +import json +import os + +import numpy as np +import requests + + +def decompress(filepath): + print("extracting: ", filepath) + zipfile = bz2.BZ2File(filepath) # open the file + data = zipfile.read() # get the decompressed data + newfilepath = os.path.splitext(filepath)[0] # assuming the filepath ends with .bz2 + with open(newfilepath, 'wb') as f: + f.write(data) # write a uncompressed file + return newfilepath + + +def download_file(url, local_folder=None): + """Downloads file pointed to by `url`. + If `local_folder` is not supplied, downloads to the current folder. 
+ """ + filename = os.path.basename(url) + if local_folder: + filename = os.path.join(local_folder, filename) + + # Download the file + print("Downloading: " + url) + response = requests.get(url, stream=True) + if response.status_code != 200: + raise Exception("download file failed with status code: %d, fetching url '%s'" % (response.status_code, url)) + + # Write the file to disk + with open(filename, "wb") as handle: + handle.write(response.content) + return filename + + +def checkIntPos(value): + ivalue = int(value) + if ivalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive int value" % value) + return ivalue + + +def checkIntNneg(value): + ivalue = int(value) + if ivalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg int value" % value) + return ivalue + + +def checkFloatNneg(value): + fvalue = float(value) + if fvalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg float value" % value) + return fvalue + + +def checkFloatPos(value): + fvalue = float(value) + if fvalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive float value" % value) + return fvalue + + +def getArgs(): + ''' + Function to parse arguments for FastCells + ''' + parser = argparse.ArgumentParser( + description='HyperParams for Fast(G)RNN') + parser.add_argument('-dir', '--data-dir', required=True, + help='Data directory containing' + + 'train.npy and test.npy') + + parser.add_argument('-c', '--cell', type=str, default="FastGRNN", + help='Choose between [FastGRNN, FastRNN, UGRNN' + + ', GRU, LSTM], default: FastGRNN') + + parser.add_argument('-id', '--input-dim', type=checkIntNneg, required=True, + help='Input Dimension of RNN, each timestep will ' + + 'feed input-dim features to RNN. ' + + 'Total Feature length = Input Dim * Total Timestep') + parser.add_argument('-hd', '--hidden-dim', type=checkIntNneg, + required=True, help='Hidden Dimension of RNN') + + parser.add_argument('-e', '--epochs', type=checkIntPos, default=300, + help='Total Epochs (default: 300 try:[100, 150, 600])') + parser.add_argument('-b', '--batch-size', type=checkIntPos, default=100, + help='Batch Size to be used (default: 100)') + parser.add_argument('-lr', '--learning-rate', type=checkFloatPos, + default=0.01, help='Initial Learning rate for ' + + 'Adam Optimizer (default: 0.01)') + + parser.add_argument('-rW', '--wRank', type=checkIntPos, default=None, + help='Rank for the low-rank parameterisation of W, ' + + 'None => Full Rank') + parser.add_argument('-rU', '--uRank', type=checkIntPos, default=None, + help='Rank for the low-rank parameterisation of U, ' + + 'None => Full Rank') + + parser.add_argument('-sW', type=checkFloatPos, default=1.0, + help='Sparsity for predictor parameter W(and both ' + + 'W1 and W2 in low-rank) ' + + '(default: 1.0(Dense) try: [0.1, 0.2, 0.3])') + parser.add_argument('-sU', type=checkFloatPos, default=1.0, + help='Sparsity for predictor parameter U(and both ' + + 'U1 and U2 in low-rank) ' + + '(default: 1.0(Dense) try: [0.1, 0.2, 0.3])') + + parser.add_argument('-unl', '--update-nl', type=str, default="tanh", + help='Update non linearity. Choose between ' + + '[tanh, sigmoid, relu, quantTanh, quantSigm]. ' + + 'default => tanh. Can add more in edgeml/graph/rnn.py') + parser.add_argument('-gnl', '--gate-nl', type=str, default="sigmoid", + help='Gate non linearity. Choose between ' + + '[tanh, sigmoid, relu, quantTanh, quantSigm]. ' + + 'default => sigmoid. Can add more in ' + + 'edgeml/graph/rnn.py. 
Only applicable to FastGRNN') + + parser.add_argument('-dS', '--decay-step', type=checkIntPos, default=200, + help='The interval (in epochs) after which the ' + + 'learning rate should decay. ' + + 'Default is 200 for 300 epochs') + + parser.add_argument('-dR', '--decay-rate', type=checkFloatPos, default=0.1, + help='The factor by which the learning rate ' + + 'should decay after each interval. Default 0.1') + + parser.add_argument('-oF', '--output-file', default=None, + help='Output file for dumping the program output ' + + '(default: stdout)') + + return parser.parse_args() + + +def getQuantArgs(): + ''' + Function to parse arguments for Model Quantisation + ''' + parser = argparse.ArgumentParser( + description='Arguments for quantizing Fast models. ' + + 'Works only for piece-wise linear non-linearities, ' + + 'like relu, quantTanh, quantSigm (check rnn.py for the definitions)') + parser.add_argument('-dir', '--model-dir', required=True, + help='model directory containing ' + + '*.npy weight files dumped from the trained model') + parser.add_argument('-m', '--max-val', type=checkIntNneg, default=127, + help='this represents the maximum possible value ' + + 'in the model, essentially the byte complexity; ' + + '127 => 1 byte is the default') + parser.add_argument('-s', '--scalar-scale', type=checkIntNneg, + default=1000, help='maximum granularity/decimals ' + + 'you wish to get when quantising the simple scalars ' + + 'involved. Default is 1000') + + return parser.parse_args() + + +def createTimeStampDir(dataDir, cell): + ''' + Creates a directory with a timestamp as its name + ''' + if os.path.isdir(os.path.join(dataDir, str(cell) + 'Results')) is False: + try: + os.mkdir(os.path.join(dataDir, str(cell) + 'Results')) + except OSError: + print("Creation of the directory %s failed" % + os.path.join(dataDir, str(cell) + 'Results')) + + currDir = os.path.join(str(cell) + 'Results', + datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")) + if os.path.isdir(os.path.join(dataDir, currDir)) is False: + try: + os.mkdir(os.path.join(dataDir, currDir)) + except OSError: + print("Creation of the directory %s failed" % + os.path.join(dataDir, currDir)) + else: + return (os.path.join(dataDir, currDir)) + return None + + +def preProcessData(dataDir): + ''' + Function to pre-process input data + + Expects a .npy file of the form [lbl feats] for each datapoint, + where feats is timesteps*inputDims, flattened across the timestep dimension. + So the input of the 1st timestep is followed by the 2nd, and so on. 
+ + Outputs train and test set datapoints; + dataDimension and numClasses are inferred directly. + ''' + train = np.load(os.path.join(dataDir, 'train.npy')) + test = np.load(os.path.join(dataDir, 'test.npy')) + + dataDimension = int(train.shape[1]) - 1 + + Xtrain = train[:, 1:dataDimension + 1] + Ytrain_ = train[:, 0] + numClasses = max(Ytrain_) - min(Ytrain_) + 1 + + Xtest = test[:, 1:dataDimension + 1] + Ytest_ = test[:, 0] + + numClasses = int(max(numClasses, max(Ytest_) - min(Ytest_) + 1)) + + # Mean Var Normalisation + mean = np.mean(Xtrain, 0) + std = np.std(Xtrain, 0) + std[std[:] < 0.000001] = 1 + Xtrain = (Xtrain - mean) / std + + Xtest = (Xtest - mean) / std + # End Mean Var normalisation + + lab = Ytrain_.astype('uint8') + lab = np.array(lab) - min(lab) + + lab_ = np.zeros((Xtrain.shape[0], numClasses)) + lab_[np.arange(Xtrain.shape[0]), lab] = 1 + Ytrain = lab_ + + lab = Ytest_.astype('uint8') + lab = np.array(lab) - min(lab) + + lab_ = np.zeros((Xtest.shape[0], numClasses)) + lab_[np.arange(Xtest.shape[0]), lab] = 1 + Ytest = lab_ + + return dataDimension, numClasses, Xtrain, Ytrain, Xtest, Ytest, mean, std + + +def dumpCommand(argv, currDir): + ''' + Dumps the current command to a file for further use + ''' + commandFile = open(os.path.join(currDir, 'command.txt'), 'w') + command = "python" + + command = command + " " + ' '.join(argv) + commandFile.write(command) + + commandFile.flush() + commandFile.close() + + +def saveMeanStd(mean, std, currDir): + ''' + Function to save Mean and Std vectors + ''' + np.save(os.path.join(currDir, 'mean.npy'), mean) + np.save(os.path.join(currDir, 'std.npy'), std) + + +def saveJSon(data, filename): + with open(filename, "w") as outfile: + json.dump(data, outfile, indent=2) diff --git a/tf2.0/examples/FastCells/process_usps.py b/tf2.0/examples/FastCells/process_usps.py new file mode 100644 index 000000000..7ff763b00 --- /dev/null +++ b/tf2.0/examples/FastCells/process_usps.py @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Processing the USPS Data. It is assumed that the data is already +# downloaded. + +import subprocess +import os +import numpy as np +from sklearn.datasets import load_svmlight_file +import sys + +def processData(workingDir, downloadDir): + def loadLibSVMFile(file): + data = load_svmlight_file(file) + features = data[0] + labels = data[1] + retMat = np.zeros([features.shape[0], features.shape[1] + 1]) + retMat[:, 0] = labels + retMat[:, 1:] = features.todense() + return retMat + + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + trf = path + '/train.txt' + tsf = path + '/test.txt' + assert os.path.isfile(trf), 'File not found: %s' % trf + assert os.path.isfile(tsf), 'File not found: %s' % tsf + train = loadLibSVMFile(trf) + test = loadLibSVMFile(tsf) + np.save(path + '/train.npy', train) + np.save(path + '/test.npy', test) + +if __name__ == '__main__': + # Configuration + workingDir = './' + downloadDir = 'usps10' + # End config + print("Processing data") + processData(workingDir, downloadDir) + print("Done") diff --git a/tf2.0/examples/FastCells/quantizeFastModels.py b/tf2.0/examples/FastCells/quantizeFastModels.py new file mode 100644 index 000000000..746f6f9f4 --- /dev/null +++ b/tf2.0/examples/FastCells/quantizeFastModels.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
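[Editor's note] The script that follows rounds a single scale factor per parameter group. As a quick sanity check on that arithmetic, here is a hypothetical worked example; the numbers are illustrative (taken from the notebook run above) and this snippet is not part of the file:

```
import numpy as np

# Suppose the largest-magnitude entry across the W/U/B weight files is
# 4.98 and max_value is the default 127 (1-byte weights).
param_limit, max_value = 4.98, 127
scale = np.round((2.0 * max_value + 1.0) / (2.0 * param_limit))  # 26.0
# A weight of 1.39 is then stored as round(26 * 1.39) = 36 in int8.
print(int(scale), int(np.round(scale * 1.39)))
```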
+ +import helpermethods +import os +import numpy as np + + +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + +def min_max(A, name): + print(name + " has max: " + str(np.max(A)) + " min: " + str(np.min(A))) + return np.max([np.abs(np.max(A)), np.abs(np.min(A))]) + + +def quantizeFastModels(modelDir, maxValue=127, scalarScaleFactor=1000): + ls = os.listdir(modelDir) + paramNameList = [] + paramWeightList = [] + paramLimitList = [] + + classifierNameList = [] + classifierWeightList = [] + classifierLimitList = [] + + scalarNameList = [] + scalarWeightList = [] + + for file in ls: + if file.endswith("npy"): + if file.startswith("W"): + paramNameList.append(file) + temp = np.load(os.path.join(modelDir, file)) + paramWeightList.append(temp) + paramLimitList.append(min_max(temp, file)) + elif file.startswith("U"): + paramNameList.append(file) + temp = np.load(os.path.join(modelDir, file)) + paramWeightList.append(temp) + paramLimitList.append(min_max(temp, file)) + elif file.startswith("B"): + paramNameList.append(file) + temp = np.load(os.path.join(modelDir, file)) + paramWeightList.append(temp) + paramLimitList.append(min_max(temp, file)) + elif file.startswith("FC"): + classifierNameList.append(file) + temp = np.load(os.path.join(modelDir, file)) + classifierWeightList.append(temp) + classifierLimitList.append(min_max(temp, file)) + elif file.startswith("mean") or file.startswith("std"): + continue + else: + scalarNameList.append(file) + scalarWeightList.append(np.load(os.path.join(modelDir, file))) + + paramLimit = np.max(paramLimitList) + classifierLimit = np.max(classifierLimitList) + + paramScaleFactor = np.round((2.0 * maxValue + 1.0) / (2.0 * paramLimit)) + classifierScaleFactor = (2.0 * maxValue + 1.0) / (2.0 * classifierLimit) + + quantParamWeights = [] + for param in paramWeightList: + temp = np.round(paramScaleFactor * param) + temp[temp[:] > maxValue] = maxValue + temp[temp[:] < -maxValue] = -1 * (maxValue + 1) + + if maxValue <= 127: + temp = temp.astype('int8') + elif maxValue <= 32767: + temp = temp.astype('int16') + else: + temp = temp.astype('int32') + + quantParamWeights.append(temp) + + quantClassifierWeights = [] + for param in classifierWeightList: + temp = np.round(classifierScaleFactor * param) + temp[temp[:] > maxValue] = maxValue + temp[temp[:] < -maxValue] = -1 * (maxValue + 1) + + if maxValue <= 127: + temp = temp.astype('int8') + elif maxValue <= 32767: + temp = temp.astype('int16') + else: + temp = temp.astype('int32') + + quantClassifierWeights.append(temp) + + quantScalarWeights = [] + for scalar in scalarWeightList: + quantScalarWeights.append( + np.round(scalarScaleFactor * sigmoid(scalar)).astype('int32')) + + quantModelDir = os.path.join(modelDir, 'QuantizedFastModel') + if not os.path.isdir(quantModelDir): + try: + os.makedirs(quantModelDir, exist_ok=True) + except OSError: + print("Creation of the directory %s failed" % quantModelDir) + + np.save(os.path.join(quantModelDir, "paramScaleFactor.npy"), + paramScaleFactor.astype('int32')) + np.save(os.path.join(quantModelDir, "classifierScaleFactor.npy"), + classifierScaleFactor) + np.save(os.path.join(quantModelDir, "scalarScaleFactor"), scalarScaleFactor) + + for i in range(0, len(scalarNameList)): + np.save(os.path.join(quantModelDir, "q" + + scalarNameList[i]), quantScalarWeights[i]) + + for i in range(len(classifierNameList)): + np.save(os.path.join(quantModelDir, "q" + + classifierNameList[i]), quantClassifierWeights[i]) + + for i in range(len(paramNameList)): + np.save(os.path.join(quantModelDir, "q" + 
paramNameList[i]), + quantParamWeights[i]) + + print("\n\nQuantized Model Dir: " + quantModelDir) + + +def main(): + args = helpermethods.getQuantArgs() + quantizeFastModels(args.model_dir, int( + args.max_val), int(args.scalar_scale)) + + +if __name__ == '__main__': + main() diff --git a/tf2.0/examples/ProtoNN/README.md b/tf2.0/examples/ProtoNN/README.md new file mode 100644 index 000000000..d0137ac4e --- /dev/null +++ b/tf2.0/examples/ProtoNN/README.md @@ -0,0 +1,54 @@ +# Tensorflow ProtoNN Examples + +This directory includes an example [notebook](protoNN_example.ipynb) and a +command line execution script for ProtoNN, developed as part of EdgeML. The +example is based on the USPS dataset. + +`edgeml.graph.protoNN` implements the ProtoNN prediction graph in Tensorflow. +The training routine for ProtoNN is decoupled from the forward graph to +facilitate a plug-and-play behaviour wherein ProtoNN can be combined with or +used as a final layer classifier for other architectures (RNNs, CNNs). The +training routine is implemented in `edgeml.trainer.protoNNTrainer`. + +Note that `protoNN_example.py` assumes the data to be in a specific format. It +is assumed that train and test data is contained in two files, `train.npy` and +`test.npy`, each containing a 2D numpy array of dimension `[numberOfExamples, +numberOfFeatures + 1]`. The first column of each matrix is assumed to contain +label information. For an N-class problem, we assume the labels are integers +from 0 through N-1. + +**Tested With:** Tensorflow >1.6 with Python 2 and Python 3 + +## Fetching Data + +The script [fetch_usps.py](fetch_usps.py) can be used to automatically +download the data, and [process_usps.py](process_usps.py) can be used to process the +data into the required format. + To run these scripts, please use: + + python fetch_usps.py + python process_usps.py + + +## Running the ProtoNN execution script + +Along with the example notebook, a command line execution script for ProtoNN is +provided in `protoNN_example.py`. After the USPS data has been set up, this +script can be used with the following command: + +``` +python protoNN_example.py \ + --data-dir ./usps10 \ + --projection-dim 60 \ + --num-prototypes 80 \ + --gamma 0.0015 \ + --learning-rate 0.1 \ + --epochs 200 \ + --val-step 10 \ + --output-dir ./ +``` + +You can expect a test set accuracy of about 92.5%. + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT license. diff --git a/tf2.0/examples/ProtoNN/fetch_usps.py b/tf2.0/examples/ProtoNN/fetch_usps.py new file mode 100644 index 000000000..c1b2e0726 --- /dev/null +++ b/tf2.0/examples/ProtoNN/fetch_usps.py @@ -0,0 +1,64 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Setting up the USPS Data. + +import subprocess +import os +import numpy as np +from sklearn.datasets import load_svmlight_file +import sys + +def downloadData(workingDir, downloadDir, linkTrain, linkTest): + def runcommand(command): + p = subprocess.Popen(command.split(), stdout=subprocess.PIPE) + output, error = p.communicate() + assert(p.returncode == 0), 'Command failed: %s' % command + + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + try: + os.mkdir(path) + except OSError: + print("Could not create %s. 
Make sure the path does" % path) + print("not already exist and you have permisions to create it.") + return False + cwd = os.getcwd() + os.chdir(path) + print("Downloading data") + command = 'wget %s' % linkTrain + runcommand(command) + command = 'wget %s' % linkTest + runcommand(command) + print("Extracting data") + command = 'bzip2 -d usps.bz2' + runcommand(command) + command = 'bzip2 -d usps.t.bz2' + runcommand(command) + command = 'mv usps train.txt' + runcommand(command) + command = 'mv usps.t test.txt' + runcommand(command) + os.chdir(cwd) + return True + +if __name__ == '__main__': + workingDir = './' + downloadDir = 'usps10' + linkTrain = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2' + linkTest = 'http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2' + failureMsg = ''' +Download Failed! +To manually perform the download +\t1. Create a new empty directory named `usps10`. +\t2. Download the data from the following links into the usps10 directory. +\t\tTest: %s +\t\tTrain: %s +\t3. Extract the downloaded files. +\t4. Rename `usps` to `train.txt` and, +\t5. Rename `usps.t` to `test.txt +''' % (linkTrain, linkTest) + + if not downloadData(workingDir, downloadDir, linkTrain, linkTest): + exit(failureMsg) + print("Done") diff --git a/tf2.0/examples/ProtoNN/helpermethods.py b/tf2.0/examples/ProtoNN/helpermethods.py new file mode 100644 index 000000000..1bd382825 --- /dev/null +++ b/tf2.0/examples/ProtoNN/helpermethods.py @@ -0,0 +1,206 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. + +from __future__ import print_function +import sys +import os +import numpy as np +import tensorflow as tf +import edgeml.utils as utils +import argparse + + +def getModelSize(matrixList, sparcityList, expected=True, bytesPerVar=4): + ''' + expected: Expected size according to the parameters set. The number of + zeros could actually be more than that is required to satisfy the + sparsity constraint. + ''' + nnzList, sizeList, isSparseList = [], [], [] + hasSparse = False + for i in range(len(matrixList)): + A, s = matrixList[i], sparcityList[i] + assert A.ndim == 2 + assert s >= 0 + assert s <= 1 + nnz, size, sparse = utils.countnnZ(A, s, bytesPerVar=bytesPerVar) + nnzList.append(nnz) + sizeList.append(size) + hasSparse = (hasSparse or sparse) + + totalnnZ = np.sum(nnzList) + totalSize = np.sum(sizeList) + if expected: + return totalnnZ, totalSize, hasSparse + numNonZero = 0 + totalSize = 0 + hasSparse = False + for i in range(len(matrixList)): + A, s = matrixList[i], sparcityList[i] + numNonZero_ = np.count_nonzero(A) + numNonZero += numNonZero_ + hasSparse = (hasSparse or (s < 0.5)) + if s <= 0.5: + totalSize += numNonZero_ * 2 * bytesPerVar + else: + totalSize += A.size * bytesPerVar + return numNonZero, totalSize, hasSparse + + +def getGamma(gammaInit, projectionDim, dataDim, numPrototypes, x_train): + if gammaInit is None: + print("Using median heuristic to estimate gamma.") + gamma, W, B = utils.medianHeuristic(x_train, projectionDim, + numPrototypes) + print("Gamma estimate is: %f" % gamma) + return W, B, gamma + return None, None, gammaInit + +def to_onehot(y, numClasses, minlabel = None): + ''' + If the y labelling does not contain the minimum label info, use min-label to + provide this value. 
+ ''' + lab = y.astype('uint8') + if minlabel is None: + minlabel = np.min(lab) + minlabel = int(minlabel) + lab = np.array(lab) - minlabel + lab_ = np.zeros((y.shape[0], numClasses)) + lab_[np.arange(y.shape[0]), lab] = 1 + return lab_ + +def preprocessData(train, test): + ''' + Performs some initial preprocessing steps on the given train and test + arrays. Each is a 2D numpy array of dimension + [numberOfExamples, numberOfFeatures + 1], as loaded from train.npy and + test.npy. The first column of each + matrix is assumed to contain label information. + + For an N-class problem, we assume the labels are integers from 0 through + N-1. + ''' + dataDimension = int(train.shape[1]) - 1 + x_train = train[:, 1:dataDimension + 1] + y_train_ = train[:, 0] + x_test = test[:, 1:dataDimension + 1] + y_test_ = test[:, 0] + + numClasses = max(y_train_) - min(y_train_) + 1 + numClasses = max(numClasses, max(y_test_) - min(y_test_) + 1) + numClasses = int(numClasses) + + # mean-var + mean = np.mean(x_train, 0) + std = np.std(x_train, 0) + std[std[:] < 0.000001] = 1 + x_train = (x_train - mean) / std + x_test = (x_test - mean) / std + + # one hot y-train + lab = y_train_.astype('uint8') + lab = np.array(lab) - min(lab) + lab_ = np.zeros((x_train.shape[0], numClasses)) + lab_[np.arange(x_train.shape[0]), lab] = 1 + y_train = lab_ + + # one hot y-test + lab = y_test_.astype('uint8') + lab = np.array(lab) - min(lab) + lab_ = np.zeros((x_test.shape[0], numClasses)) + lab_[np.arange(x_test.shape[0]), lab] = 1 + y_test = lab_ + + return dataDimension, numClasses, x_train, y_train, x_test, y_test + + + +def getProtoNNArgs(): + def checkIntPos(value): + ivalue = int(value) + if ivalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive int value" % value) + return ivalue + + def checkIntNneg(value): + ivalue = int(value) + if ivalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg int value" % value) + return ivalue + + def checkFloatNneg(value): + fvalue = float(value) + if fvalue < 0: + raise argparse.ArgumentTypeError( + "%s is an invalid non-neg float value" % value) + return fvalue + + def checkFloatPos(value): + fvalue = float(value) + if fvalue <= 0: + raise argparse.ArgumentTypeError( + "%s is an invalid positive float value" % value) + return fvalue + + ''' + Parse protoNN commandline arguments + ''' + parser = argparse.ArgumentParser( + description='Hyperparameters for ProtoNN Algorithm') + + msg = 'Data directory containing train and test data. The ' + msg += 'data is assumed to be saved as 2-D numpy matrices with ' + msg += 'names `train.npy` and `test.npy`, of dimensions\n' + msg += '\t[numberOfInstances, numberOfFeatures + 1].\n' + msg += 'The first column of each file is assumed to contain label information.' + msg += ' For an N-class problem, labels are assumed to be integers from 0 to' + msg += ' N-1 (inclusive).' + parser.add_argument('-d', '--data-dir', required=True, help=msg) + parser.add_argument('-l', '--projection-dim', type=checkIntPos, default=10, + help='Projection Dimension.') + parser.add_argument('-p', '--num-prototypes', type=checkIntPos, default=20, + help='Number of prototypes.') + parser.add_argument('-g', '--gamma', type=checkFloatPos, default=None, + help='Gamma for Gaussian kernel. 
If not provided, ' + + 'the median heuristic will be used to estimate gamma.') + + parser.add_argument('-e', '--epochs', type=checkIntPos, default=100, + help='Total training epochs.') + parser.add_argument('-b', '--batch-size', type=checkIntPos, default=32, + help='Batch size for each pass.') + parser.add_argument('-r', '--learning-rate', type=checkFloatPos, + default=0.001, + help='Initial Learning rate for ADAM Optimizer.') + + parser.add_argument('-rW', type=float, default=0.000, + help='Coefficient for l2 regularizer for predictor' + + ' parameter W ' + '(default = 0.0).') + parser.add_argument('-rB', type=float, default=0.00, + help='Coefficient for l2 regularizer for predictor' + + ' parameter B ' + '(default = 0.0).') + parser.add_argument('-rZ', type=float, default=0.00, + help='Coefficient for l2 regularizer for predictor' + + ' parameter Z ' + + '(default = 0.0).') + + parser.add_argument('-sW', type=float, default=1.000, + help='Sparsity constraint for predictor parameter W ' + + '(default = 1.0, i.e. dense matrix).') + parser.add_argument('-sB', type=float, default=1.00, + help='Sparsity constraint for predictor parameter B ' + + '(default = 1.0, i.e. dense matrix).') + parser.add_argument('-sZ', type=float, default=1.00, + help='Sparsity constraint for predictor parameter Z ' + + '(default = 1.0, i.e. dense matrix).') + parser.add_argument('-pS', '--print-step', type=int, default=200, + help='The number of update steps between print ' + + 'calls to console.') + parser.add_argument('-vS', '--val-step', type=int, default=3, + help='The number of epochs between validation ' + + 'performance evaluation.') + parser.add_argument('-o', '--output-dir', type=str, default='./', + help='Output directory to dump model matrices.') + return parser.parse_args() diff --git a/tf2.0/examples/ProtoNN/process_usps.py b/tf2.0/examples/ProtoNN/process_usps.py new file mode 100644 index 000000000..dee4d1bbb --- /dev/null +++ b/tf2.0/examples/ProtoNN/process_usps.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +# +# Processing the USPS Data. It is assumed that the data is already +# downloaded. 
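[Editor's note] Every example script in this patch consumes the same `[label, features...]` layout, so a quick check of the processed files can save debugging time. This is an illustrative snippet, not part of the patch; it assumes fetch_usps.py and process_usps.py have already been run:

```
import numpy as np

# Each row of train.npy / test.npy is [label, feat_0, ..., feat_{d-1}];
# after processing, labels are integers 0..N-1 in column 0.
train = np.load('./usps10/train.npy')
labels, feats = train[:, 0], train[:, 1:]
assert np.all(labels == labels.astype(int)) and labels.min() == 0
print(train.shape)  # (numberOfExamples, numberOfFeatures + 1)
```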
+ +import subprocess +import os +import numpy as np +from sklearn.datasets import load_svmlight_file +import sys +from helpermethods import preprocessData + +def processData(workingDir, downloadDir): + def loadLibSVMFile(file): + data = load_svmlight_file(file) + features = data[0] + labels = data[1] + retMat = np.zeros([features.shape[0], features.shape[1] + 1]) + retMat[:, 0] = labels + retMat[:, 1:] = features.todense() + return retMat + + path = workingDir + '/' + downloadDir + path = os.path.abspath(path) + trf = path + '/train.txt' + tsf = path + '/test.txt' + assert os.path.isfile(trf), 'File not found: %s' % trf + assert os.path.isfile(tsf), 'File not found: %s' % tsf + train = loadLibSVMFile(trf) + test = loadLibSVMFile(tsf) + np.save(path + '/train_unnormalized.npy', train) + np.save(path + '/test_unnormalized.npy', test) + _, _, x_train, y_train, x_test, y_test = preprocessData(train, test) + + y_ = np.expand_dims(np.argmax(y_train, axis=1), axis=1) + train = np.concatenate([y_, x_train], axis=1) + np.save(path + '/train.npy', train) + y_ = np.expand_dims(np.argmax(y_test, axis=1), axis=1) + test = np.concatenate([y_, x_test], axis=1) + np.save(path + '/test.npy', test) + + +if __name__ == '__main__': + # Configuration + workingDir = './' + downloadDir = 'usps10' + # End config + print("Processing data") + processData(workingDir, downloadDir) + print("Done") diff --git a/tf2.0/examples/ProtoNN/protoNN_example.ipynb b/tf2.0/examples/ProtoNN/protoNN_example.ipynb new file mode 100644 index 000000000..9581b97e9 --- /dev/null +++ b/tf2.0/examples/ProtoNN/protoNN_example.ipynb @@ -0,0 +1,449 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ProtoNN in Tensorflow\n", + "\n", + "This is a simple notebook that illustrates the usage of the Tensorflow implementation of ProtoNN. We are using the USPS dataset. Please refer to `fetch_usps.py` and `process_usps.py` for more details on downloading the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T13:06:10.223951Z", + "start_time": "2018-08-15T13:06:09.303454Z" + } + }, + "outputs": [], + "source": [ + "# Copyright (c) Microsoft Corporation. All rights reserved.\n", + "# Licensed under the MIT license.\n", + "\n", + "from __future__ import print_function\n", + "import sys\n", + "import os\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "\n", + "from edgeml.trainer.protoNNTrainer import ProtoNNTrainer\n", + "from edgeml.graph.protoNN import ProtoNN\n", + "import edgeml.utils as utils\n", + "import helpermethods as helper" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# USPS Data\n", + "\n", + "It is assumed that the USPS data has already been downloaded and set up with the help of [fetch_usps.py](fetch_usps.py) and is placed in the `./usps10` subdirectory." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T13:06:10.271026Z",
+     "start_time": "2018-08-15T13:06:10.225900Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# Load data\n",
+    "DATA_DIR = './usps10'\n",
+    "train, test = np.load(DATA_DIR + '/train.npy'), np.load(DATA_DIR + '/test.npy')\n",
+    "x_train, y_train = train[:, 1:], train[:, 0]\n",
+    "x_test, y_test = test[:, 1:], test[:, 0]\n",
+    "\n",
+    "# Infer the number of classes from the label range of both splits.\n",
+    "numClasses = max(y_train) - min(y_train) + 1\n",
+    "numClasses = max(numClasses, max(y_test) - min(y_test) + 1)\n",
+    "numClasses = int(numClasses)\n",
+    "\n",
+    "y_train = helper.to_onehot(y_train, numClasses)\n",
+    "y_test = helper.to_onehot(y_test, numClasses)\n",
+    "dataDimension = x_train.shape[1]\n",
+    "numClasses = y_train.shape[1]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Model Parameters\n",
+    "\n",
+    "Note that ProtoNN is very sensitive to the value of the hyperparameter $\\gamma$, stored here in the variable `GAMMA`. If `GAMMA` is set to `None`, the median heuristic will be used to estimate a good value of $\\gamma$ through the `helper.getGamma()` method. This method also returns the corresponding `W` and `B` matrices, which should be used to initialize ProtoNN (as is done here)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T13:06:10.279204Z",
+     "start_time": "2018-08-15T13:06:10.272880Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "PROJECTION_DIM = 60\n",
+    "NUM_PROTOTYPES = 60\n",
+    "REG_W = 0.000005\n",
+    "REG_B = 0.0\n",
+    "REG_Z = 0.00005\n",
+    "SPAR_W = 0.8\n",
+    "SPAR_B = 1.0\n",
+    "SPAR_Z = 1.0\n",
+    "LEARNING_RATE = 0.05\n",
+    "NUM_EPOCHS = 200\n",
+    "BATCH_SIZE = 32\n",
+    "GAMMA = 0.0015"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T13:06:10.307632Z",
+     "start_time": "2018-08-15T13:06:10.280955Z"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "W, B, gamma = helper.getGamma(GAMMA, PROJECTION_DIM, dataDimension,\n",
+    "                              NUM_PROTOTYPES, x_train)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2018-08-15T13:07:22.641991Z",
+     "start_time": "2018-08-15T13:06:10.309353Z"
+    },
+    "scrolled": true
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Epoch: 0 Batch: 0 Loss: 5.85158 Accuracy: 0.03125\n",
+      "Epoch: 1 Batch: 0 Loss: 1.53823 Accuracy: 0.65625\n",
+      "Epoch: 2 Batch: 0 Loss: 0.81371 Accuracy: 0.87500\n",
+      "Epoch: 3 Batch: 0 Loss: 0.51246 Accuracy: 0.87500\n",
+      "Epoch: 4 Batch: 0 Loss: 0.41875 Accuracy: 0.93750\n",
+      "Epoch: 5 Batch: 0 Loss: 0.36797 Accuracy: 0.96875\n",
+      "Epoch: 6 Batch: 0 Loss: 0.32868 Accuracy: 0.96875\n",
+      "Epoch: 7 Batch: 0 Loss: 0.30316 Accuracy: 0.96875\n",
+      "Epoch: 8 Batch: 0 Loss: 0.29075 Accuracy: 0.96875\n",
+      "Epoch: 9 Batch: 0 Loss: 0.28370 Accuracy: 0.96875\n",
+      "Test Loss: 0.50615 Accuracy: 0.89497\n",
+      "Epoch: 10 Batch: 0 Loss: 0.28014 Accuracy: 0.96875\n",
+      "Epoch: 11 Batch: 0 Loss: 0.27734 Accuracy: 0.96875\n",
+      "Epoch: 12 Batch: 0 Loss: 0.27511 Accuracy: 0.96875\n",
+      "Epoch: 13 Batch: 0 Loss: 0.27126 Accuracy: 0.96875\n",
+      "Epoch: 14 Batch: 0 Loss: 0.26776 Accuracy: 0.96875\n",
+      "Epoch: 15 Batch: 0 Loss: 0.26506 Accuracy: 0.96875\n",
+      "Epoch: 16 Batch: 0 Loss: 0.26371 Accuracy: 0.96875\n",
+      "Epoch: 17 Batch: 0 Loss: 0.26249 Accuracy: 0.96875\n",
+      "Epoch: 18 Batch: 0 Loss: 0.26094 Accuracy: 
0.96875\n", + "Epoch: 19 Batch: 0 Loss: 0.25879 Accuracy: 0.96875\n", + "Test Loss: 0.54362 Accuracy: 0.89494\n", + "Epoch: 20 Batch: 0 Loss: 0.25642 Accuracy: 0.96875\n", + "Epoch: 21 Batch: 0 Loss: 0.25328 Accuracy: 0.96875\n", + "Epoch: 22 Batch: 0 Loss: 0.25015 Accuracy: 0.96875\n", + "Epoch: 23 Batch: 0 Loss: 0.24684 Accuracy: 0.96875\n", + "Epoch: 24 Batch: 0 Loss: 0.24365 Accuracy: 0.96875\n", + "Epoch: 25 Batch: 0 Loss: 0.24023 Accuracy: 0.96875\n", + "Epoch: 26 Batch: 0 Loss: 0.23747 Accuracy: 0.96875\n", + "Epoch: 27 Batch: 0 Loss: 0.23460 Accuracy: 0.96875\n", + "Epoch: 28 Batch: 0 Loss: 0.23170 Accuracy: 0.96875\n", + "Epoch: 29 Batch: 0 Loss: 0.22903 Accuracy: 0.96875\n", + "Test Loss: 0.54884 Accuracy: 0.89391\n", + "Epoch: 30 Batch: 0 Loss: 0.22662 Accuracy: 0.96875\n", + "Epoch: 31 Batch: 0 Loss: 0.22448 Accuracy: 0.96875\n", + "Epoch: 32 Batch: 0 Loss: 0.22245 Accuracy: 0.96875\n", + "Epoch: 33 Batch: 0 Loss: 0.22068 Accuracy: 0.96875\n", + "Epoch: 34 Batch: 0 Loss: 0.21904 Accuracy: 0.96875\n", + "Epoch: 35 Batch: 0 Loss: 0.21723 Accuracy: 0.96875\n", + "Epoch: 36 Batch: 0 Loss: 0.21582 Accuracy: 0.96875\n", + "Epoch: 37 Batch: 0 Loss: 0.21409 Accuracy: 0.96875\n", + "Epoch: 38 Batch: 0 Loss: 0.21246 Accuracy: 0.96875\n", + "Epoch: 39 Batch: 0 Loss: 0.21095 Accuracy: 0.96875\n", + "Test Loss: 0.52917 Accuracy: 0.90091\n", + "Epoch: 40 Batch: 0 Loss: 0.20928 Accuracy: 0.96875\n", + "Epoch: 41 Batch: 0 Loss: 0.20770 Accuracy: 0.96875\n", + "Epoch: 42 Batch: 0 Loss: 0.20633 Accuracy: 0.96875\n", + "Epoch: 43 Batch: 0 Loss: 0.20512 Accuracy: 0.96875\n", + "Epoch: 44 Batch: 0 Loss: 0.20377 Accuracy: 0.96875\n", + "Epoch: 45 Batch: 0 Loss: 0.20240 Accuracy: 0.96875\n", + "Epoch: 46 Batch: 0 Loss: 0.20124 Accuracy: 0.96875\n", + "Epoch: 47 Batch: 0 Loss: 0.20002 Accuracy: 0.96875\n", + "Epoch: 48 Batch: 0 Loss: 0.19910 Accuracy: 0.96875\n", + "Epoch: 49 Batch: 0 Loss: 0.19808 Accuracy: 0.96875\n", + "Test Loss: 0.50988 Accuracy: 0.90292\n", + "Epoch: 50 Batch: 0 Loss: 0.19705 Accuracy: 0.96875\n", + "Epoch: 51 Batch: 0 Loss: 0.19629 Accuracy: 1.00000\n", + "Epoch: 52 Batch: 0 Loss: 0.19560 Accuracy: 1.00000\n", + "Epoch: 53 Batch: 0 Loss: 0.19483 Accuracy: 1.00000\n", + "Epoch: 54 Batch: 0 Loss: 0.19404 Accuracy: 1.00000\n", + "Epoch: 55 Batch: 0 Loss: 0.19351 Accuracy: 1.00000\n", + "Epoch: 56 Batch: 0 Loss: 0.19279 Accuracy: 1.00000\n", + "Epoch: 57 Batch: 0 Loss: 0.19250 Accuracy: 1.00000\n", + "Epoch: 58 Batch: 0 Loss: 0.19207 Accuracy: 1.00000\n", + "Epoch: 59 Batch: 0 Loss: 0.19169 Accuracy: 1.00000\n", + "Test Loss: 0.48988 Accuracy: 0.90443\n", + "Epoch: 60 Batch: 0 Loss: 0.19146 Accuracy: 1.00000\n", + "Epoch: 61 Batch: 0 Loss: 0.19119 Accuracy: 1.00000\n", + "Epoch: 62 Batch: 0 Loss: 0.19095 Accuracy: 1.00000\n", + "Epoch: 63 Batch: 0 Loss: 0.19077 Accuracy: 1.00000\n", + "Epoch: 64 Batch: 0 Loss: 0.19066 Accuracy: 1.00000\n", + "Epoch: 65 Batch: 0 Loss: 0.19071 Accuracy: 1.00000\n", + "Epoch: 66 Batch: 0 Loss: 0.19066 Accuracy: 1.00000\n", + "Epoch: 67 Batch: 0 Loss: 0.19071 Accuracy: 1.00000\n", + "Epoch: 68 Batch: 0 Loss: 0.19083 Accuracy: 1.00000\n", + "Epoch: 69 Batch: 0 Loss: 0.19090 Accuracy: 1.00000\n", + "Test Loss: 0.47286 Accuracy: 0.90841\n", + "Epoch: 70 Batch: 0 Loss: 0.19108 Accuracy: 1.00000\n", + "Epoch: 71 Batch: 0 Loss: 0.19110 Accuracy: 1.00000\n", + "Epoch: 72 Batch: 0 Loss: 0.19109 Accuracy: 1.00000\n", + "Epoch: 73 Batch: 0 Loss: 0.19118 Accuracy: 1.00000\n", + "Epoch: 74 Batch: 0 Loss: 0.19115 Accuracy: 1.00000\n", + "Epoch: 75 Batch: 0 Loss: 
0.19122 Accuracy: 1.00000\n", + "Epoch: 76 Batch: 0 Loss: 0.19099 Accuracy: 1.00000\n", + "Epoch: 77 Batch: 0 Loss: 0.19087 Accuracy: 1.00000\n", + "Epoch: 78 Batch: 0 Loss: 0.19070 Accuracy: 1.00000\n", + "Epoch: 79 Batch: 0 Loss: 0.19069 Accuracy: 0.96875\n", + "Test Loss: 0.46010 Accuracy: 0.91289\n", + "Epoch: 80 Batch: 0 Loss: 0.19055 Accuracy: 0.96875\n", + "Epoch: 81 Batch: 0 Loss: 0.19055 Accuracy: 0.96875\n", + "Epoch: 82 Batch: 0 Loss: 0.19013 Accuracy: 0.96875\n", + "Epoch: 83 Batch: 0 Loss: 0.19005 Accuracy: 0.96875\n", + "Epoch: 84 Batch: 0 Loss: 0.18991 Accuracy: 0.96875\n", + "Epoch: 85 Batch: 0 Loss: 0.18985 Accuracy: 0.96875\n", + "Epoch: 86 Batch: 0 Loss: 0.18961 Accuracy: 0.96875\n", + "Epoch: 87 Batch: 0 Loss: 0.18926 Accuracy: 0.96875\n", + "Epoch: 88 Batch: 0 Loss: 0.18901 Accuracy: 0.96875\n", + "Epoch: 89 Batch: 0 Loss: 0.18866 Accuracy: 0.96875\n", + "Test Loss: 0.45069 Accuracy: 0.91588\n", + "Epoch: 90 Batch: 0 Loss: 0.18821 Accuracy: 0.96875\n", + "Epoch: 91 Batch: 0 Loss: 0.18801 Accuracy: 0.96875\n", + "Epoch: 92 Batch: 0 Loss: 0.18799 Accuracy: 0.96875\n", + "Epoch: 93 Batch: 0 Loss: 0.18779 Accuracy: 0.96875\n", + "Epoch: 94 Batch: 0 Loss: 0.18743 Accuracy: 0.96875\n", + "Epoch: 95 Batch: 0 Loss: 0.18732 Accuracy: 0.96875\n", + "Epoch: 96 Batch: 0 Loss: 0.18720 Accuracy: 0.96875\n", + "Epoch: 97 Batch: 0 Loss: 0.18696 Accuracy: 0.96875\n", + "Epoch: 98 Batch: 0 Loss: 0.18674 Accuracy: 0.96875\n", + "Epoch: 99 Batch: 0 Loss: 0.18637 Accuracy: 0.96875\n", + "Test Loss: 0.44313 Accuracy: 0.91739\n", + "Epoch: 100 Batch: 0 Loss: 0.18625 Accuracy: 0.96875\n", + "Epoch: 101 Batch: 0 Loss: 0.18607 Accuracy: 0.96875\n", + "Epoch: 102 Batch: 0 Loss: 0.18596 Accuracy: 0.96875\n", + "Epoch: 103 Batch: 0 Loss: 0.18585 Accuracy: 0.96875\n", + "Epoch: 104 Batch: 0 Loss: 0.18578 Accuracy: 0.96875\n", + "Epoch: 105 Batch: 0 Loss: 0.18561 Accuracy: 0.96875\n", + "Epoch: 106 Batch: 0 Loss: 0.18538 Accuracy: 0.96875\n", + "Epoch: 107 Batch: 0 Loss: 0.18525 Accuracy: 0.96875\n", + "Epoch: 108 Batch: 0 Loss: 0.18512 Accuracy: 0.96875\n", + "Epoch: 109 Batch: 0 Loss: 0.18494 Accuracy: 0.96875\n", + "Test Loss: 0.43641 Accuracy: 0.91939\n", + "Epoch: 110 Batch: 0 Loss: 0.18505 Accuracy: 0.96875\n", + "Epoch: 111 Batch: 0 Loss: 0.18486 Accuracy: 0.96875\n", + "Epoch: 112 Batch: 0 Loss: 0.18481 Accuracy: 0.96875\n", + "Epoch: 113 Batch: 0 Loss: 0.18461 Accuracy: 0.96875\n", + "Epoch: 114 Batch: 0 Loss: 0.18439 Accuracy: 0.96875\n", + "Epoch: 115 Batch: 0 Loss: 0.18411 Accuracy: 0.96875\n", + "Epoch: 116 Batch: 0 Loss: 0.18385 Accuracy: 0.96875\n", + "Epoch: 117 Batch: 0 Loss: 0.18367 Accuracy: 0.96875\n", + "Epoch: 118 Batch: 0 Loss: 0.18353 Accuracy: 0.96875\n", + "Epoch: 119 Batch: 0 Loss: 0.18333 Accuracy: 0.96875\n", + "Test Loss: 0.43042 Accuracy: 0.92185\n", + "Epoch: 120 Batch: 0 Loss: 0.18315 Accuracy: 0.96875\n", + "Epoch: 121 Batch: 0 Loss: 0.18304 Accuracy: 0.96875\n", + "Epoch: 122 Batch: 0 Loss: 0.18268 Accuracy: 0.96875\n", + "Epoch: 123 Batch: 0 Loss: 0.18227 Accuracy: 0.96875\n", + "Epoch: 124 Batch: 0 Loss: 0.18209 Accuracy: 0.96875\n", + "Epoch: 125 Batch: 0 Loss: 0.18206 Accuracy: 0.96875\n", + "Epoch: 126 Batch: 0 Loss: 0.18184 Accuracy: 0.96875\n", + "Epoch: 127 Batch: 0 Loss: 0.18185 Accuracy: 0.96875\n", + "Epoch: 128 Batch: 0 Loss: 0.18159 Accuracy: 0.96875\n", + "Epoch: 129 Batch: 0 Loss: 0.18152 Accuracy: 0.96875\n", + "Test Loss: 0.42608 Accuracy: 0.92284\n", + "Epoch: 130 Batch: 0 Loss: 0.18125 Accuracy: 0.96875\n", + "Epoch: 131 Batch: 0 Loss: 0.18102 
Accuracy: 0.96875\n", + "Epoch: 132 Batch: 0 Loss: 0.18075 Accuracy: 0.96875\n", + "Epoch: 133 Batch: 0 Loss: 0.18055 Accuracy: 0.96875\n", + "Epoch: 134 Batch: 0 Loss: 0.18025 Accuracy: 0.96875\n", + "Epoch: 135 Batch: 0 Loss: 0.18015 Accuracy: 0.96875\n", + "Epoch: 136 Batch: 0 Loss: 0.18000 Accuracy: 0.96875\n", + "Epoch: 137 Batch: 0 Loss: 0.17977 Accuracy: 0.96875\n", + "Epoch: 138 Batch: 0 Loss: 0.17963 Accuracy: 0.96875\n", + "Epoch: 139 Batch: 0 Loss: 0.17950 Accuracy: 0.96875\n", + "Test Loss: 0.42305 Accuracy: 0.92238\n", + "Epoch: 140 Batch: 0 Loss: 0.17935 Accuracy: 0.96875\n", + "Epoch: 141 Batch: 0 Loss: 0.17908 Accuracy: 0.96875\n", + "Epoch: 142 Batch: 0 Loss: 0.17903 Accuracy: 0.96875\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 143 Batch: 0 Loss: 0.17900 Accuracy: 0.96875\n", + "Epoch: 144 Batch: 0 Loss: 0.17877 Accuracy: 0.96875\n", + "Epoch: 145 Batch: 0 Loss: 0.17863 Accuracy: 0.96875\n", + "Epoch: 146 Batch: 0 Loss: 0.17844 Accuracy: 0.96875\n", + "Epoch: 147 Batch: 0 Loss: 0.17825 Accuracy: 1.00000\n", + "Epoch: 148 Batch: 0 Loss: 0.17808 Accuracy: 1.00000\n", + "Epoch: 149 Batch: 0 Loss: 0.17798 Accuracy: 1.00000\n", + "Test Loss: 0.42054 Accuracy: 0.92288\n", + "Epoch: 150 Batch: 0 Loss: 0.17788 Accuracy: 1.00000\n", + "Epoch: 151 Batch: 0 Loss: 0.17773 Accuracy: 1.00000\n", + "Epoch: 152 Batch: 0 Loss: 0.17753 Accuracy: 1.00000\n", + "Epoch: 153 Batch: 0 Loss: 0.17743 Accuracy: 1.00000\n", + "Epoch: 154 Batch: 0 Loss: 0.17736 Accuracy: 1.00000\n", + "Epoch: 155 Batch: 0 Loss: 0.17724 Accuracy: 1.00000\n", + "Epoch: 156 Batch: 0 Loss: 0.17721 Accuracy: 1.00000\n", + "Epoch: 157 Batch: 0 Loss: 0.17721 Accuracy: 1.00000\n", + "Epoch: 158 Batch: 0 Loss: 0.17713 Accuracy: 1.00000\n", + "Epoch: 159 Batch: 0 Loss: 0.17701 Accuracy: 1.00000\n", + "Test Loss: 0.41836 Accuracy: 0.92337\n", + "Epoch: 160 Batch: 0 Loss: 0.17695 Accuracy: 1.00000\n", + "Epoch: 161 Batch: 0 Loss: 0.17691 Accuracy: 1.00000\n", + "Epoch: 162 Batch: 0 Loss: 0.17687 Accuracy: 1.00000\n", + "Epoch: 163 Batch: 0 Loss: 0.17687 Accuracy: 1.00000\n", + "Epoch: 164 Batch: 0 Loss: 0.17682 Accuracy: 1.00000\n", + "Epoch: 165 Batch: 0 Loss: 0.17686 Accuracy: 1.00000\n", + "Epoch: 166 Batch: 0 Loss: 0.17686 Accuracy: 1.00000\n", + "Epoch: 167 Batch: 0 Loss: 0.17687 Accuracy: 1.00000\n", + "Epoch: 168 Batch: 0 Loss: 0.17688 Accuracy: 1.00000\n", + "Epoch: 169 Batch: 0 Loss: 0.17688 Accuracy: 1.00000\n", + "Test Loss: 0.41637 Accuracy: 0.92536\n", + "Epoch: 170 Batch: 0 Loss: 0.17685 Accuracy: 1.00000\n", + "Epoch: 171 Batch: 0 Loss: 0.17684 Accuracy: 1.00000\n", + "Epoch: 172 Batch: 0 Loss: 0.17688 Accuracy: 1.00000\n", + "Epoch: 173 Batch: 0 Loss: 0.17695 Accuracy: 1.00000\n", + "Epoch: 174 Batch: 0 Loss: 0.17693 Accuracy: 1.00000\n", + "Epoch: 175 Batch: 0 Loss: 0.17693 Accuracy: 1.00000\n", + "Epoch: 176 Batch: 0 Loss: 0.17697 Accuracy: 1.00000\n", + "Epoch: 177 Batch: 0 Loss: 0.17707 Accuracy: 1.00000\n", + "Epoch: 178 Batch: 0 Loss: 0.17715 Accuracy: 1.00000\n", + "Epoch: 179 Batch: 0 Loss: 0.17728 Accuracy: 1.00000\n", + "Test Loss: 0.41443 Accuracy: 0.92585\n", + "Epoch: 180 Batch: 0 Loss: 0.17732 Accuracy: 1.00000\n", + "Epoch: 181 Batch: 0 Loss: 0.17738 Accuracy: 1.00000\n", + "Epoch: 182 Batch: 0 Loss: 0.17747 Accuracy: 1.00000\n", + "Epoch: 183 Batch: 0 Loss: 0.17750 Accuracy: 1.00000\n", + "Epoch: 184 Batch: 0 Loss: 0.17762 Accuracy: 1.00000\n", + "Epoch: 185 Batch: 0 Loss: 0.17775 Accuracy: 1.00000\n", + "Epoch: 186 Batch: 0 Loss: 0.17791 Accuracy: 
1.00000\n", + "Epoch: 187 Batch: 0 Loss: 0.17795 Accuracy: 1.00000\n", + "Epoch: 188 Batch: 0 Loss: 0.17811 Accuracy: 1.00000\n", + "Epoch: 189 Batch: 0 Loss: 0.17818 Accuracy: 1.00000\n", + "Test Loss: 0.41260 Accuracy: 0.92536\n", + "Epoch: 190 Batch: 0 Loss: 0.17835 Accuracy: 1.00000\n", + "Epoch: 191 Batch: 0 Loss: 0.17847 Accuracy: 1.00000\n", + "Epoch: 192 Batch: 0 Loss: 0.17855 Accuracy: 1.00000\n", + "Epoch: 193 Batch: 0 Loss: 0.17863 Accuracy: 1.00000\n", + "Epoch: 194 Batch: 0 Loss: 0.17880 Accuracy: 1.00000\n", + "Epoch: 195 Batch: 0 Loss: 0.17885 Accuracy: 1.00000\n", + "Epoch: 196 Batch: 0 Loss: 0.17900 Accuracy: 1.00000\n", + "Epoch: 197 Batch: 0 Loss: 0.17912 Accuracy: 1.00000\n", + "Epoch: 198 Batch: 0 Loss: 0.17927 Accuracy: 1.00000\n", + "Epoch: 199 Batch: 0 Loss: 0.17948 Accuracy: 1.00000\n", + "Test Loss: 0.41087 Accuracy: 0.92536\n" + ] + } + ], + "source": [ + "# Setup input and train protoNN\n", + "X = tf.placeholder(tf.float32, [None, dataDimension], name='X')\n", + "Y = tf.placeholder(tf.float32, [None, numClasses], name='Y')\n", + "protoNN = ProtoNN(dataDimension, PROJECTION_DIM,\n", + " NUM_PROTOTYPES, numClasses,\n", + " gamma, W=W, B=B)\n", + "trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,\n", + " SPAR_W, SPAR_B, SPAR_Z,\n", + " LEARNING_RATE, X, Y, lossType='xentropy')\n", + "sess = tf.Session()\n", + "trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test, y_train, y_test,\n", + " printStep=600, valStep=10)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Model Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2018-08-15T13:07:22.671507Z", + "start_time": "2018-08-15T13:07:22.645050Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Final test accuracy 0.92526156\n", + "Model size constraint (Bytes): 78240\n", + "Number of non-zeros: 19560\n", + "Actual model size: 78240\n", + "Actual non-zeros: 16488\n" + ] + } + ], + "source": [ + "acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})\n", + "# W, B, Z are tensorflow graph nodes\n", + "W, B, Z, _ = protoNN.getModelMatrices()\n", + "matrixList = sess.run([W, B, Z])\n", + "sparcityList = [SPAR_W, SPAR_B, SPAR_Z]\n", + "nnz, size, sparse = helper.getModelSize(matrixList, sparcityList)\n", + "print(\"Final test accuracy\", acc)\n", + "print(\"Model size constraint (Bytes): \", size)\n", + "print(\"Number of non-zeros: \", nnz)\n", + "nnz, size, sparse = helper.getModelSize(matrixList, sparcityList, expected=False)\n", + "print(\"Actual model size: \", size)\n", + "print(\"Actual non-zeros: \", nnz)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tf2.0/examples/ProtoNN/protoNN_example.py b/tf2.0/examples/ProtoNN/protoNN_example.py new file mode 100644 index 000000000..9b49c6542 --- /dev/null +++ b/tf2.0/examples/ProtoNN/protoNN_example.py @@ -0,0 +1,88 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf
+from edgeml.trainer.protoNNTrainer import ProtoNNTrainer
+from edgeml.graph.protoNN import ProtoNN
+import helpermethods as helper
+
+tf.compat.v1.disable_eager_execution()
+
+
+def main():
+    config = helper.getProtoNNArgs()
+    # Get hyperparameters
+    DATA_DIR = config.data_dir
+    PROJECTION_DIM = config.projection_dim
+    NUM_PROTOTYPES = config.num_prototypes
+    REG_W = config.rW
+    REG_B = config.rB
+    REG_Z = config.rZ
+    SPAR_W = config.sW
+    SPAR_B = config.sB
+    SPAR_Z = config.sZ
+    LEARNING_RATE = config.learning_rate
+    NUM_EPOCHS = config.epochs
+    BATCH_SIZE = config.batch_size
+    PRINT_STEP = config.print_step
+    VAL_STEP = config.val_step
+    OUT_DIR = config.output_dir
+
+    # Load data; column 0 holds the label, the remaining columns the features.
+    train = np.load(DATA_DIR + '/train.npy')
+    test = np.load(DATA_DIR + '/test.npy')
+    x_train, y_train = train[:, 1:], train[:, 0]
+    x_test, y_test = test[:, 1:], test[:, 0]
+    # Convert y to one-hot
+    minval = min(min(y_train), min(y_test))
+    numClasses = max(y_train) - min(y_train) + 1
+    numClasses = max(numClasses, max(y_test) - min(y_test) + 1)
+    numClasses = int(numClasses)
+    y_train = helper.to_onehot(y_train, numClasses, minlabel=minval)
+    y_test = helper.to_onehot(y_test, numClasses, minlabel=minval)
+    dataDimension = x_train.shape[1]
+
+    W, B, gamma = helper.getGamma(config.gamma, PROJECTION_DIM, dataDimension,
+                                  NUM_PROTOTYPES, x_train)
+
+    # Set up inputs and train ProtoNN
+    X = tf.compat.v1.placeholder(tf.float32, [None, dataDimension], name='X')
+    Y = tf.compat.v1.placeholder(tf.float32, [None, numClasses], name='Y')
+    protoNN = ProtoNN(dataDimension, PROJECTION_DIM,
+                      NUM_PROTOTYPES, numClasses,
+                      gamma, W=W, B=B)
+    trainer = ProtoNNTrainer(protoNN, REG_W, REG_B, REG_Z,
+                             SPAR_W, SPAR_B, SPAR_Z,
+                             LEARNING_RATE, X, Y, lossType='xentropy')
+    sess = tf.compat.v1.Session()
+    trainer.train(BATCH_SIZE, NUM_EPOCHS, sess, x_train, x_test,
+                  y_train, y_test, printStep=PRINT_STEP, valStep=VAL_STEP)
+
+    # Print some summary metrics
+    acc = sess.run(protoNN.accuracy, feed_dict={X: x_test, Y: y_test})
+    # W, B, Z are TensorFlow graph nodes
+    W, B, Z, gamma = protoNN.getModelMatrices()
+    matrixList = sess.run([W, B, Z])
+    gamma = sess.run(gamma)
+    sparsityList = [SPAR_W, SPAR_B, SPAR_Z]
+    nnz, size, sparse = helper.getModelSize(matrixList, sparsityList)
+    print("Final test accuracy", acc)
+    print("Model size constraint (Bytes): ", size)
+    print("Number of non-zeros: ", nnz)
+    nnz, size, sparse = helper.getModelSize(matrixList, sparsityList,
+                                            expected=False)
+    print("Actual model size: ", size)
+    print("Actual non-zeros: ", nnz)
+    print("Saving model matrices to: ", OUT_DIR)
+    np.save(OUT_DIR + '/W.npy', matrixList[0])
+    np.save(OUT_DIR + '/B.npy', matrixList[1])
+    np.save(OUT_DIR + '/Z.npy', matrixList[2])
+    np.save(OUT_DIR + '/gamma.npy', gamma)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tf2.0/requirements-cpu.txt b/tf2.0/requirements-cpu.txt
new file mode 100644
index 000000000..25fdc1788
--- /dev/null
+++ b/tf2.0/requirements-cpu.txt
@@ -0,0 +1,7 @@
+jupyter==1.0.0
+numpy==1.14.5
+pandas==0.23.4
+scikit-learn==0.19.2
+scipy==1.1.0
+# The TF 2.0 port needs a TensorFlow build from the 2.0 line, not 1.10.
+tensorflow==2.0.0-alpha0
+requests
\ No newline at end of file
diff --git a/tf2.0/requirements-gpu.txt b/tf2.0/requirements-gpu.txt
new file mode 100644
index 000000000..f181ccc0a
--- /dev/null
+++ b/tf2.0/requirements-gpu.txt
@@ -0,0 +1,7 @@
+jupyter==1.0.0
+numpy==1.14.5
+pandas==0.23.4
+scikit-learn==0.19.2
+scipy==1.1.0
+# The TF 2.0 port needs a TensorFlow build from the 2.0 line, not 1.10.
+tensorflow-gpu==2.0.0-alpha0
+requests
\ No newline at end of file
diff --git a/tf2.0/setup.py b/tf2.0/setup.py
new file mode 100644
index 000000000..dfb6fac46
--- /dev/null
+++ b/tf2.0/setup.py
@@ -0,0 +1,9 @@
+from distutils.core import setup
+
+setup(
+    name='edgeml',
+    version='0.2',
+    packages=['edgeml', ],
+    license='MIT License',
+    long_description=open('../License.txt').read(),
+)
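+
+# A quick smoke test after installing the package (a suggested check, with the
+# module paths taken from this patch; it is not part of the build):
+#
+#     python -c "import tensorflow as tf; \
+#                tf.compat.v1.disable_eager_execution(); \
+#                from edgeml.graph.protoNN import ProtoNN; \
+#                from edgeml.trainer.protoNNTrainer import ProtoNNTrainer; \
+#                print('edgeml import OK')"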