From 807ca417e20ea464eac549885fcf72237e3b47a0 Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Sat, 13 Jul 2024 17:12:17 +0800 Subject: [PATCH 01/10] add model provider - Amazon Sagemaker --- .../model_providers/sagemaker/__init__.py | 0 .../sagemaker/_assets/icon_l_en.png | Bin 0 -> 9395 bytes .../sagemaker/_assets/icon_s_en.png | Bin 0 -> 9720 bytes .../model_providers/sagemaker/llm/__init__.py | 0 .../model_providers/sagemaker/llm/llm.py | 244 ++++++++++++++++++ .../sagemaker/rerank/__init__.py | 0 .../sagemaker/rerank/rerank.py | 189 ++++++++++++++ .../model_providers/sagemaker/sagemaker.py | 19 ++ .../model_providers/sagemaker/sagemaker.yaml | 117 +++++++++ .../sagemaker/text_embedding/__init__.py | 0 .../text_embedding/text_embedding.py | 216 ++++++++++++++++ 11 files changed, 785 insertions(+) create mode 100644 api/core/model_runtime/model_providers/sagemaker/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png create mode 100644 api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png create mode 100644 api/core/model_runtime/model_providers/sagemaker/llm/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/llm/llm.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/sagemaker.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml create mode 100644 api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py create mode 100644 api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py diff --git a/api/core/model_runtime/model_providers/sagemaker/__init__.py b/api/core/model_runtime/model_providers/sagemaker/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_l_en.png new file mode 100644 index 0000000000000000000000000000000000000000..0abe07a78ff776018db04d84b7cc0ceb1c9ae81f GIT binary patch literal 9395 zcmai)WmsEFwD$uo5Ue;9in|tCpm=e2cLK$o;0{Gv2<{YjFD^wxDG&+;iWj#6MS{ED zlylDeKF|GfKP1_EXYH9aGi%nKJ->gXnu;to1{nqb0Kk@)lhObHkV_HAAAzWdZ?@^f z7laapHGjSiJ(WauH znX3hb#6i<4EOjqRsbqeQBcRqg-{O0hg~XB)v8o&}06oFbhAUveW?OQmb{1VDzrAYr z6ERHf>-n?AX=4wiA_}iM8%xE$W>Zhiv(>B@pt!9iw~;$OFRL!9=lJ&f3!+;&ej^w6 zBL}1Tp`lLz-CkXhqyS3*3c#HDTR&wmYw&S+){=G;Ork3^hX2daA;qixA1^w?CFrS~{(SY!O*+De_ zG5^o5kLA&v=0Cf&;I+jes~NWTt(ovxh3EyNZEq?*Y|kw;u`i+!J(bDKYHq+dWxM9 zJ)fyv9BDh!8&Y8Q6e-+C0$A z(I6z@hXUVrS)s+kBjtr>Jx{IG{~6yk9&x45#e8h{NU|nAYN2^aut|DTwqYHue`>?( z&^g`w$)0!q#aWkJf3pk<_J&%=OZ+572 zuUXe|*`u#L)e-!6Xygx#IB^2~F)j-)5(!G2_FICdyn-LDYvBG|RVsTS_YJddn4< z1}hCtTXrvYHNc-9q^9M2N#25-0(IPUf6z*%{Tm(VLvFT0G%y1)M(Q%dwfie+ymBfu zL^l<`0mf*Hu(AGiaDeMjVot^8u6;E$z9Vs&N|ZF4vpnfn(D{#~^`;BU#6U^204PKC z%?95r&0|c16_9}wHCEatC5PPTA54Gi%;(4My)6jhnyGnczWisl zap#y5q!(grKKexSHE)9lAj_C2q+=v*^!JX)+~|}_q=74k`k#%vyv>Atw#;)kOX9lf zCs&-mmcruSXOzK^^o+aeGd%;+IHdm8AU1?*MPM*j%CBru|HIy)KQbxabaVzQIOO&> z;XJa~lVI#T)V%4&-Aso*t)&UlKC5=+b|SbX67Mm&kX5A&>%f{)XPaXOFT8-z!{b5?~vFEu_?HyG%&}8dv zOf)$$++#;r$R3OE-BM1x8GeZmpr(OAuCz9J=ve$l8?T!u6Z}Q)hpXUgRw#IxZfz0WlIDsyfoX(?b z+0P4^J7lnaT)~b(C*GODE0(G;==_yrsFC z7;k*Mv^ns!_LSJgt=MZ8163~TTDia+DNuHxwt$X0im~!p{X3NyWg*h>^y{gYS@3E_R4n?iUje@z2oxDS>iBsp#`#M$P;0V@QvB~t*$jE^lu$DdA8NSU%m(_wF0 
[... base85 binary patch data for icon_l_en.png omitted ...]
literal 0
HcmV?d00001

diff --git a/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png b/api/core/model_runtime/model_providers/sagemaker/_assets/icon_s_en.png
new file mode 100644
index 0000000000000000000000000000000000000000..6b88942a5ce27cb57e27ce876cecedf3c5dfce3b
GIT binary patch
literal 9720
zlMf8cPTh!bdYUW85Zvs0kd1P6f)e;;3loVZzD?D4#y+3IAjJDE*625BtSI}hos9=X z(LBjqgdnr;Zr(Cd#3HSa&Kn|^Fe#nSi2*2q!Zz{!>cCppGm*<{47Fb;!3jVz!$b%S zSv@zh#qM{vdxyWSre+qlwqBwz{ezR<|1SJRsyJSQU;JECaEVIu5-cwf3H#B0(e2Ow OOmb4nk~QL{VgCa)i~r65 literal 0 HcmV?d00001 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py b/api/core/model_runtime/model_providers/sagemaker/llm/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py new file mode 100644 index 00000000000000..ca3c43b825eb50 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -0,0 +1,244 @@ +import json +import logging +from collections.abc import Generator, Iterator +from typing import Optional, Union, cast, Any +import boto3 + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + +logger = logging.getLogger(__name__) + + +class SageMakerLargeLanguageModel(LargeLanguageModel): + """ + Model class for Cohere large language model. 
+ """ + sagemaker_client: Any = None + + def _invoke(self, model: str, credentials: dict, + prompt_messages: list[PromptMessage], model_parameters: dict, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, + stream: bool = True, user: Optional[str] = None) \ + -> Union[LLMResult, Generator]: + """ + Invoke large language model + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param model_parameters: model parameters + :param tools: tools for tool calling + :param stop: stop words + :param stream: is stream response + :param user: unique user id + :return: full response or stream response chunk generator result + """ + # get model mode + model_mode = self.get_model_mode(model, credentials) + + if not self.sagemaker_client: + access_key = credentials.get('access_key', None) + secret_key = credentials.get('secret_key', None) + aws_region = credentials.get('aws_region', None) + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + + sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=sagemaker_endpoint, + Body=json.dumps( + { + "inputs": prompt_messages[0].content, + "parameters": { "stop" : stop}, + "history" : [] + } + ), + ContentType="application/json", + ) + + assistant_text = response_model['Body'].read().decode('utf8') + + # transform assistant message to prompt message + assistant_prompt_message = AssistantPromptMessage( + content=assistant_text + ) + + usage = self._calc_response_usage(model, credentials, 0, 0) + + response = LLMResult( + model=model, + prompt_messages=prompt_messages, + message=assistant_prompt_message, + usage=usage + ) + + return response + + def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], + tools: Optional[list[PromptMessageTool]] = None) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param prompt_messages: prompt messages + :param tools: tools for tool calling + :return: + """ + # get model mode + model_mode = self.get_model_mode(model) + + try: + return 0 + except Exception as e: + raise self._transform_invoke_error(e) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # get model mode + model_mode = self.get_model_mode(model) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. 
+ + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + rules = [ + ParameterRule( + name='temperature', + type=ParameterType.FLOAT, + use_template='temperature', + label=I18nObject( + zh_Hans='温度', + en_US='Temperature' + ), + ), + ParameterRule( + name='top_p', + type=ParameterType.FLOAT, + use_template='top_p', + label=I18nObject( + zh_Hans='Top P', + en_US='Top P' + ) + ), + ParameterRule( + name='max_tokens', + type=ParameterType.INT, + use_template='max_tokens', + min=1, + max=credentials.get('context_length', 2048), + default=512, + label=I18nObject( + zh_Hans='最大生成长度', + en_US='Max Tokens' + ) + ) + ] + + completion_type = LLMMode.value_of(credentials["mode"]) + + if completion_type == LLMMode.CHAT: + print(f"completion_type : {LLMMode.CHAT.value}") + + if completion_type == LLMMode.COMPLETION: + print(f"completion_type : {LLMMode.COMPLETION.value}") + + features = [] + + support_function_call = credentials.get('support_function_call', False) + if support_function_call: + features.append(ModelFeature.TOOL_CALL) + + support_vision = credentials.get('support_vision', False) + if support_vision: + features.append(ModelFeature.VISION) + + context_length = credentials.get('context_length', 2048) + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.LLM, + features=features, + model_properties={ + ModelPropertyKey.MODE: completion_type, + ModelPropertyKey.CONTEXT_SIZE: context_length + }, + parameter_rules=rules + ) + + return entity diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py b/api/core/model_runtime/model_providers/sagemaker/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py new file mode 100644 index 00000000000000..31bb4285f494ca --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -0,0 +1,189 @@ +from typing import Optional, Any, Union +import logging +import boto3 +import json + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + +logger = logging.getLogger(__name__) + +class SageMakerRerankModel(RerankModel): + """ + Model class for Cohere rerank model. 
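+
+    The query is paired with every candidate document and sent to a SageMaker
+    endpoint that is assumed to accept a JSON body of the form
+        {"inputs": ["<query>", "<query>", ...], "docs": ["doc 1", "doc 2", ...]}
+    and to return {"scores": [...]}, one relevance score per document.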
+ """ + sagemaker_client: Any = None + + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): + inputs = [query_input]*len(docs) + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=rerank_endpoint, + Body=json.dumps( + { + "inputs": inputs, + "docs": docs + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + scores = json_obj['scores'] + return scores if isinstance(scores, list) else [scores] + + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) \ + -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n + :param user: unique user id + :return: rerank result + """ + line = 0 + try: + if len(docs) == 0: + return RerankResult( + model=model, + docs=docs + ) + + line = 1 + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id', None) + secret_key = credentials.get('aws_secret_access_key', None) + aws_region = credentials.get('aws_region', None) + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 2 + + sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + candidate_docs = [] + + scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) + for idx in range(len(scores)): + candidate_docs.append({"content" : docs[idx], "score": scores[idx]}) + + sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + + line = 3 + rerank_documents = [] + for idx, result in enumerate(candidate_docs): + rerank_document = RerankDocument( + index=idx, + text=result.get('content'), + score=result.get('score', -100.0) + ) + + if score_threshold is not None: + if rerank_document.score >= score_threshold: + rerank_documents.append(rerank_document) + else: + rerank_documents.append(rerank_document) + + return RerankResult( + model=model, + docs=rerank_documents + ) + + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + self.invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + The key is the error type thrown to the caller + The value is the error type thrown by the model, + which needs to be converted into a unified error type for the caller. + + :return: Invoke error mapping + """ + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + InvokeBadRequestError, + KeyError, + ValueError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.RERANK, + model_properties={ }, + parameter_rules=[] + ) + + return entity \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py new file mode 100644 index 00000000000000..a475e2a2c689a7 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -0,0 +1,19 @@ +import logging + +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.model_provider import ModelProvider + +logger = logging.getLogger(__name__) + + +class SageMakerProvider(ModelProvider): + def validate_provider_credentials(self, credentials: dict) -> None: + """ + Validate provider credentials + + if validate failed, raise exception + + :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
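+
+        SageMaker endpoints are registered as customizable models, so there
+        are no provider-level credentials to verify and this method is a
+        no-op; credentials are supplied and checked per model instead.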
+ """ + pass diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml new file mode 100644 index 00000000000000..968d7adde130c9 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -0,0 +1,117 @@ +provider: sagemaker +label: + zh_Hans: Sagemaker + en_US: Sagemaker +icon_small: + en_US: icon_s_en.png +icon_large: + en_US: icon_l_en.png +description: + en_US: Customized model on Sagemaker + zh_Hans: Sagemaker上的私有化部署的模型 +background: "#ECE9E3" +help: + title: + en_US: How to deploy customized model on Sagemaker + zh_Hans: 如何在Sagemaker上的私有化部署的模型 + url: + en_US: https://github.com/aws-samples/dify-aws-tool/blob/main/README.md#how-to-deploy-sagemaker-endpoint + zh_Hans: https://github.com/aws-samples/dify-aws-tool/blob/main/README_ZH.md#%E5%A6%82%E4%BD%95%E9%83%A8%E7%BD%B2sagemaker%E6%8E%A8%E7%90%86%E7%AB%AF%E7%82%B9 +supported_model_types: + - llm + - text-embedding + - rerank +configurate_methods: + - customizable-model +model_credential_schema: + model: + label: + en_US: Model Name + zh_Hans: 模型名称 + placeholder: + en_US: Enter your model name + zh_Hans: 输入模型名称 + credential_form_schemas: + - variable: mode + show_on: + - variable: __model_type + value: llm + label: + en_US: Completion mode + type: select + required: false + default: chat + placeholder: + zh_Hans: 选择对话类型 + en_US: Select completion mode + options: + - value: completion + label: + en_US: Completion + zh_Hans: 补全 + - value: chat + label: + en_US: Chat + zh_Hans: 对话 + - variable: sagemaker_endpoint + label: + en_US: sagemaker endpoint + type: text-input + required: true + placeholder: + zh_Hans: 请输出你的Sagemaker推理端点 + en_US: Enter your Sagemaker Inference endpoint + - variable: aws_access_key_id + required: false + label: + en_US: Access Key (If not provided, credentials are obtained from the running environment.) + zh_Hans: Access Key (如果未提供,凭证将从运行环境中获取。) + type: secret-input + placeholder: + en_US: Enter your Access Key + zh_Hans: 在此输入您的 Access Key + - variable: aws_secret_access_key + required: false + label: + en_US: Secret Access Key + zh_Hans: Secret Access Key + type: secret-input + placeholder: + en_US: Enter your Secret Access Key + zh_Hans: 在此输入您的 Secret Access Key + - variable: aws_region + required: false + label: + en_US: AWS Region + zh_Hans: AWS 地区 + type: select + default: us-east-1 + options: + - value: us-east-1 + label: + en_US: US East (N. 
Virginia) + zh_Hans: 美国东部 (弗吉尼亚北部) + - value: us-west-2 + label: + en_US: US West (Oregon) + zh_Hans: 美国西部 (俄勒冈州) + - value: ap-southeast-1 + label: + en_US: Asia Pacific (Singapore) + zh_Hans: 亚太地区 (新加坡) + - value: ap-northeast-1 + label: + en_US: Asia Pacific (Tokyo) + zh_Hans: 亚太地区 (东京) + - value: eu-central-1 + label: + en_US: Europe (Frankfurt) + zh_Hans: 欧洲 (法兰克福) + - value: us-gov-west-1 + label: + en_US: AWS GovCloud (US-West) + zh_Hans: AWS GovCloud (US-West) + - value: ap-southeast-2 + label: + en_US: Asia Pacific (Sydney) + zh_Hans: 亚太地区 (悉尼) \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py new file mode 100644 index 00000000000000..f58d421e718da1 --- /dev/null +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -0,0 +1,216 @@ +import time +from typing import Optional, Any + +import numpy as np +import boto3 +import json +import logging +import itertools + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, ModelPropertyKey +from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + +BATCH_SIZE = 20 +CONTEXT_SIZE=8192 + +logger = logging.getLogger(__name__) + +def batch_generator(generator, batch_size): + while True: + batch = list(itertools.islice(generator, batch_size)) + if not batch: + break + yield batch + +class SageMakerEmbeddingModel(TextEmbeddingModel): + """ + Model class for Cohere text embedding model. 
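+
+    Input texts are truncated to CONTEXT_SIZE characters and sent to a
+    SageMaker endpoint in batches of BATCH_SIZE; the endpoint is assumed to
+    return a JSON body of the form {"embeddings": [[...], [...], ...]},
+    one embedding vector per input text.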
+ """ + sagemaker_client: Any = None + + def _sagemaker_embedding(self, sm_client, endpoint_name, content_list:list[str]): + response_model = sm_client.invoke_endpoint( + EndpointName=endpoint_name, + Body=json.dumps( + { + "inputs": content_list, + "parameters": {}, + "is_query" : False, + "instruction" : '' + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + embeddings = json_obj['embeddings'] + return embeddings + + def _invoke(self, model: str, credentials: dict, + texts: list[str], user: Optional[str] = None) \ + -> TextEmbeddingResult: + """ + Invoke text embedding model + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :param user: unique user id + :return: embeddings result + """ + # get model properties + try: + line = 1 + if not self.sagemaker_client: + access_key = credentials.get('aws_access_key_id', None) + secret_key = credentials.get('aws_secret_access_key', None) + aws_region = credentials.get('aws_region', None) + if aws_region: + if access_key and secret_key: + self.sagemaker_client = boto3.client("sagemaker-runtime", + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 2 + sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + + line = 3 + truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] + + batches = batch_generator((text for text in truncated_texts), batch_size=BATCH_SIZE) + all_embeddings = [] + + line = 4 + for batch in batches: + embeddings = self._sagemaker_embedding(self.sagemaker_client, sagemaker_endpoint, batch) + all_embeddings.extend(embeddings) + + line = 5 + # calc usage + usage = self._calc_response_usage( + model=model, + credentials=credentials, + tokens=0 # It's not SAAS API, usage is meaningless + ) + line = 6 + + return TextEmbeddingResult( + embeddings=all_embeddings, + usage=usage, + model=model + ) + + except Exception as e: + logger.exception(f'Exception {e}, line : {line}') + + def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int: + """ + Get number of tokens for given prompt messages + + :param model: model name + :param credentials: model credentials + :param texts: texts to embed + :return: + """ + return 0 + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + print("validate_credentials ok....") + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: + """ + Calculate response usage + + :param model: model name + :param credentials: model credentials + :param tokens: input tokens + :return: usage + """ + # get input price info + input_price_info = self.get_price( + model=model, + credentials=credentials, + price_type=PriceType.INPUT, + tokens=tokens + ) + + # transform usage + usage = EmbeddingUsage( + tokens=tokens, + total_tokens=tokens, + unit_price=input_price_info.unit_price, + price_unit=input_price_info.unit, + total_price=input_price_info.total_amount, + currency=input_price_info.currency, + latency=time.perf_counter() - self.started_at + ) + + return usage + + @property + def 
_invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + return { + InvokeConnectionError: [ + InvokeConnectionError + ], + InvokeServerUnavailableError: [ + InvokeServerUnavailableError + ], + InvokeRateLimitError: [ + InvokeRateLimitError + ], + InvokeAuthorizationError: [ + InvokeAuthorizationError + ], + InvokeBadRequestError: [ + KeyError + ] + } + + def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: + """ + used to define customizable model schema + """ + + entity = AIModelEntity( + model=model, + label=I18nObject( + en_US=model + ), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TEXT_EMBEDDING, + model_properties={ + ModelPropertyKey.CONTEXT_SIZE: CONTEXT_SIZE, + ModelPropertyKey.MAX_CHUNKS: BATCH_SIZE, + }, + parameter_rules=[] + ) + + return entity From 536a1b00ed63b72c03f99a2a1329bcb022c1869c Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Sun, 14 Jul 2024 09:40:28 +0800 Subject: [PATCH 02/10] Add integration test for Amazon SageMaker --- .../sagemaker/rerank/rerank.py | 2 +- .../model_runtime/sagemaker/__init__.py | 0 .../model_runtime/sagemaker/test_provider.py | 19 +++++++ .../model_runtime/sagemaker/test_rerank.py | 55 +++++++++++++++++++ .../sagemaker/test_text_embedding.py | 55 +++++++++++++++++++ 5 files changed, 130 insertions(+), 1 deletion(-) create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/__init__.py create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_provider.py create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py create mode 100644 api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index 31bb4285f494ca..a1c830ca098343 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -126,7 +126,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - self.invoke( + self._invoke( model=model, credentials=credentials, query="What is the capital of the United States?", diff --git a/api/tests/integration_tests/model_runtime/sagemaker/__init__.py b/api/tests/integration_tests/model_runtime/sagemaker/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py new file mode 100644 index 00000000000000..639227e7450343 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_provider.py @@ -0,0 +1,19 @@ +import os + +import pytest + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.sagemaker import SageMakerProvider + + +def test_validate_provider_credentials(): + provider = SageMakerProvider() + + with pytest.raises(CredentialsValidateFailedError): + provider.validate_provider_credentials( + credentials={} + ) + + provider.validate_provider_credentials( + credentials={} + ) diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py new file mode 100644 index 00000000000000..c67849dd798883 --- /dev/null +++ 
b/api/tests/integration_tests/model_runtime/sagemaker/test_rerank.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.rerank_entities import RerankResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.rerank.rerank import SageMakerRerankModel + + +def test_validate_credentials(): + model = SageMakerRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + +def test_invoke_model(): + model = SageMakerRerankModel() + + result = model.invoke( + model='bge-m3-rerank-v2', + credentials={ + "aws_region": os.getenv("AWS_REGION"), + "aws_access_key": os.getenv("AWS_ACCESS_KEY"), + "aws_secret_access_key": os.getenv("AWS_SECRET_ACCESS_KEY") + }, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. 
Its capital is Saipan.", + ], + score_threshold=0.8 + ) + + assert isinstance(result, RerankResult) + assert len(result.docs) == 1 + assert result.docs[0].index == 1 + assert result.docs[0].score >= 0.8 diff --git a/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py new file mode 100644 index 00000000000000..e817e8f04ab67c --- /dev/null +++ b/api/tests/integration_tests/model_runtime/sagemaker/test_text_embedding.py @@ -0,0 +1,55 @@ +import os + +import pytest + +from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.sagemaker.text_embedding.text_embedding import SageMakerEmbeddingModel + + +def test_validate_credentials(): + model = SageMakerEmbeddingModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-m3', + credentials={ + } + ) + + model.validate_credentials( + model='bge-m3-embedding', + credentials={ + } + ) + + +def test_invoke_model(): + model = SageMakerEmbeddingModel() + + result = model.invoke( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + "hello", + "world" + ], + user="abc-123" + ) + + assert isinstance(result, TextEmbeddingResult) + assert len(result.embeddings) == 2 + +def test_get_num_tokens(): + model = SageMakerEmbeddingModel() + + num_tokens = model.get_num_tokens( + model='bge-m3-embedding', + credentials={ + }, + texts=[ + ] + ) + + assert num_tokens == 0 From 98105a9c8b451f5a0382ce42fd2bb53c6f0f5a18 Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Thu, 18 Jul 2024 13:22:47 +0800 Subject: [PATCH 03/10] Fix yaml - issue --- .../model_providers/sagemaker/sagemaker.yaml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml index 968d7adde130c9..dbcc3f4c265a38 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -114,4 +114,13 @@ model_credential_schema: - value: ap-southeast-2 label: en_US: Asia Pacific (Sydney) - zh_Hans: 亚太地区 (悉尼) \ No newline at end of file + zh_Hans: 亚太地区 (悉尼) + - value: cn-north-1 + label: + en_US: AWS Beijing (cn-north-1) + zh_Hans: 中国北京 (cn-north-1) + - value: cn-northwest-1 + label: + en_US: AWS Ningxia (cn-northwest-1) + zh_Hans: 中国宁夏 (cn-northwest-1) + From 4910e1501e7b163489d45f2d7df99a3be463b2a9 Mon Sep 17 00:00:00 2001 From: crazywoola <427733928@qq.com> Date: Thu, 18 Jul 2024 19:22:37 +0800 Subject: [PATCH 04/10] fix: lint --- .../model_providers/sagemaker/llm/llm.py | 22 ++++------ .../sagemaker/rerank/rerank.py | 13 +++--- .../model_providers/sagemaker/sagemaker.py | 2 - .../text_embedding/text_embedding.py | 42 +++++++++---------- 4 files changed, 35 insertions(+), 44 deletions(-) diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index ca3c43b825eb50..f8e7757a969f8e 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -1,21 +1,15 @@ import json import logging -from collections.abc import Generator, Iterator -from typing import Optional, Union, cast, Any +from collections.abc import Generator +from typing import 
Any, Optional, Union + import boto3 -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, - PromptMessageContentType, - PromptMessageRole, PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, ) from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, I18nObject, ModelType from core.model_runtime.errors.invoke import ( @@ -60,9 +54,9 @@ def _invoke(self, model: str, credentials: dict, model_mode = self.get_model_mode(model, credentials) if not self.sagemaker_client: - access_key = credentials.get('access_key', None) - secret_key = credentials.get('secret_key', None) - aws_region = credentials.get('aws_region', None) + access_key = credentials.get('access_key') + secret_key = credentials.get('secret_key') + aws_region = credentials.get('aws_region') if aws_region: if access_key and secret_key: self.sagemaker_client = boto3.client("sagemaker-runtime", @@ -75,7 +69,7 @@ def _invoke(self, model: str, credentials: dict, self.sagemaker_client = boto3.client("sagemaker-runtime") - sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + sagemaker_endpoint = credentials.get('sagemaker_endpoint') response_model = self.sagemaker_client.invoke_endpoint( EndpointName=sagemaker_endpoint, Body=json.dumps( diff --git a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py index a1c830ca098343..0b06f54ef1823f 100644 --- a/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/sagemaker/rerank/rerank.py @@ -1,7 +1,8 @@ -from typing import Optional, Any, Union +import json import logging +from typing import Any, Optional + import boto3 -import json from core.model_runtime.entities.common_entities import I18nObject from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType @@ -69,9 +70,9 @@ def _invoke(self, model: str, credentials: dict, line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id', None) - secret_key = credentials.get('aws_secret_access_key', None) - aws_region = credentials.get('aws_region', None) + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') if aws_region: if access_key and secret_key: self.sagemaker_client = boto3.client("sagemaker-runtime", @@ -85,7 +86,7 @@ def _invoke(self, model: str, credentials: dict, line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + sagemaker_endpoint = credentials.get('sagemaker_endpoint') candidate_docs = [] scores = self._sagemaker_rerank(query, docs, sagemaker_endpoint) diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py index a475e2a2c689a7..02d05f406c50f7 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.py +++ b/api/core/model_runtime/model_providers/sagemaker/sagemaker.py @@ -1,7 +1,5 @@ import logging -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError 
from core.model_runtime.model_providers.__base.model_provider import ModelProvider logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py index f58d421e718da1..4b2858b1a28228 100644 --- a/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/sagemaker/text_embedding/text_embedding.py @@ -1,23 +1,21 @@ +import itertools +import json +import logging import time -from typing import Optional, Any +from typing import Any, Optional -import numpy as np import boto3 -import json -import logging -import itertools from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType, ModelPropertyKey -from core.model_runtime.entities.model_entities import PriceType +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel @@ -28,11 +26,11 @@ logger = logging.getLogger(__name__) def batch_generator(generator, batch_size): - while True: - batch = list(itertools.islice(generator, batch_size)) - if not batch: - break - yield batch + while True: + batch = list(itertools.islice(generator, batch_size)) + if not batch: + break + yield batch class SageMakerEmbeddingModel(TextEmbeddingModel): """ @@ -74,9 +72,9 @@ def _invoke(self, model: str, credentials: dict, try: line = 1 if not self.sagemaker_client: - access_key = credentials.get('aws_access_key_id', None) - secret_key = credentials.get('aws_secret_access_key', None) - aws_region = credentials.get('aws_region', None) + access_key = credentials.get('aws_access_key_id') + secret_key = credentials.get('aws_secret_access_key') + aws_region = credentials.get('aws_region') if aws_region: if access_key and secret_key: self.sagemaker_client = boto3.client("sagemaker-runtime", @@ -89,7 +87,7 @@ def _invoke(self, model: str, credentials: dict, self.sagemaker_client = boto3.client("sagemaker-runtime") line = 2 - sagemaker_endpoint = credentials.get('sagemaker_endpoint', None) + sagemaker_endpoint = credentials.get('sagemaker_endpoint') line = 3 truncated_texts = [ item[:CONTEXT_SIZE] for item in texts ] From 6861cdfff48c62121569654c532667658db50217 Mon Sep 17 00:00:00 2001 From: crazywoola <427733928@qq.com> Date: Thu, 18 Jul 2024 19:27:59 +0800 Subject: [PATCH 05/10] fix: lint --- api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml | 1 - 1 file changed, 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml index dbcc3f4c265a38..290cb0edabee09 100644 --- a/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml +++ 
b/api/core/model_runtime/model_providers/sagemaker/sagemaker.yaml @@ -123,4 +123,3 @@ model_credential_schema: label: en_US: AWS Ningxia (cn-northwest-1) zh_Hans: 中国宁夏 (cn-northwest-1) - From 397b932a94e0565b5f98f6a1759e0248931ea5df Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Fri, 26 Jul 2024 18:56:19 +0800 Subject: [PATCH 06/10] add aws builtin tools --- .../provider/builtin/aws/_assets/icon.svg | 9 ++ api/core/tools/provider/builtin/aws/aws.py | 25 ++++ api/core/tools/provider/builtin/aws/aws.yaml | 15 ++ .../builtin/aws/tools/apply_guardrail.py | 84 +++++++++++ .../builtin/aws/tools/apply_guardrail.yaml | 56 ++++++++ .../aws/tools/lambda_translate_utils.py | 87 ++++++++++++ .../aws/tools/lambda_translate_utils.yaml | 134 ++++++++++++++++++ .../aws/tools/sagemaker_text_rerank.py | 85 +++++++++++ .../aws/tools/sagemaker_text_rerank.yaml | 82 +++++++++++ 9 files changed, 577 insertions(+) create mode 100644 api/core/tools/provider/builtin/aws/_assets/icon.svg create mode 100644 api/core/tools/provider/builtin/aws/aws.py create mode 100644 api/core/tools/provider/builtin/aws/aws.yaml create mode 100644 api/core/tools/provider/builtin/aws/tools/apply_guardrail.py create mode 100644 api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml create mode 100644 api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py create mode 100644 api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml create mode 100644 api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py create mode 100644 api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml diff --git a/api/core/tools/provider/builtin/aws/_assets/icon.svg b/api/core/tools/provider/builtin/aws/_assets/icon.svg new file mode 100644 index 00000000000000..ecfcfc08d4eeff --- /dev/null +++ b/api/core/tools/provider/builtin/aws/_assets/icon.svg @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/aws.py b/api/core/tools/provider/builtin/aws/aws.py new file mode 100644 index 00000000000000..13ede9601509f5 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/aws.py @@ -0,0 +1,25 @@ +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.provider.builtin.aws.tools.sagemaker_text_rerank import SageMakerReRankTool +from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController + + +class SageMakerProvider(BuiltinToolProviderController): + def _validate_credentials(self, credentials: dict) -> None: + try: + SageMakerReRankTool().fork_tool_runtime( + runtime={ + "credentials": credentials, + } + ).invoke( + user_id='', + tool_parameters={ + "sagemaker_endpoint" : "", + "query": "misaka mikoto", + "candidate_texts" : "hello$$$hello world", + "topk" : 5, + "aws_region" : "" + }, + ) + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/aws.yaml b/api/core/tools/provider/builtin/aws/aws.yaml new file mode 100644 index 00000000000000..203214abf81869 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/aws.yaml @@ -0,0 +1,15 @@ +identity: + author: AWS + name: aws + label: + en_US: AWS + zh_Hans: 亚马逊云科技 + pt_BR: AWS + description: + en_US: Services on AWS. + zh_Hans: 亚马逊云科技的各类服务 + pt_BR: Services on AWS. 
+ icon: icon.svg + tags: + - search +credentials_for_provider: \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py new file mode 100644 index 00000000000000..658a9753bb5e8f --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -0,0 +1,84 @@ +import boto3 +import json +import logging +from typing import Any, Dict, Union, List +from pydantic import BaseModel, Field + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +class GuardrailParameters(BaseModel): + guardrail_id: str = Field(..., description="The identifier of the guardrail") + guardrail_version: str = Field(..., description="The version of the guardrail") + source: str = Field(..., description="The source of the content") + text: str = Field(..., description="The text to apply the guardrail to") + aws_region: str = Field(default="us-east-1", description="AWS region for the Bedrock client") + +class ApplyGuardrailTool(BuiltinTool): + def _invoke(self, + user_id: str, + tool_parameters: Dict[str, Any] + ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + """ + Invoke the ApplyGuardrail tool + """ + try: + # Validate and parse input parameters + params = GuardrailParameters(**tool_parameters) + + # Initialize AWS client + bedrock_client = boto3.client('bedrock-runtime', region_name=params.aws_region) + + # Apply guardrail + response = bedrock_client.apply_guardrail( + guardrailIdentifier=params.guardrail_id, + guardrailVersion=params.guardrail_version, + source=params.source, + content=[{"text": {"text": params.text}}] + ) + + logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") + + # Check for empty response + if not response: + return self.create_text_message(text="Received empty response from AWS Bedrock.") + + # Process the result + action = response.get("action", "No action specified") + outputs = response.get("outputs", []) + output = outputs[0].get("text", "No output received") if outputs else "No output received" + assessments = response.get("assessments", []) + + # Format assessments + formatted_assessments = [] + for assessment in assessments: + for policy_type, policy_data in assessment.items(): + if isinstance(policy_data, dict) and 'topics' in policy_data: + for topic in policy_data['topics']: + formatted_assessments.append(f"Policy: {policy_type}, Topic: {topic['name']}, Type: {topic['type']}, Action: {topic['action']}") + else: + formatted_assessments.append(f"Policy: {policy_type}, Data: {policy_data}") + + result = f"Action: {action}\n " + result += f"Output: {output}\n " + if formatted_assessments: + result += "Assessments:\n " + "\n ".join(formatted_assessments) + "\n " +# result += f"Full response: {json.dumps(response, indent=2, ensure_ascii=False)}" + + return self.create_text_message(text=result) + + except boto3.exceptions.BotoCoreError as e: + error_message = f'AWS service error: {str(e)}' + logger.error(error_message, exc_info=True) + return self.create_text_message(text=error_message) + except json.JSONDecodeError as e: + error_message = f'JSON parsing error: {str(e)}' + logger.error(error_message, exc_info=True) + return self.create_text_message(text=error_message) + except Exception as e: + error_message = f'An unexpected error occurred: {str(e)}' + logger.error(error_message, exc_info=True) + 
return self.create_text_message(text=error_message) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml new file mode 100644 index 00000000000000..83e2ea05afd8de --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml @@ -0,0 +1,56 @@ +identity: + name: apply_guardrail + author: AWS + label: + en_US: Content Moderation Guardrails + zh_Hans: 内容审查护栏 +description: + human: + en_US: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation. + zh_Hans: 内容审查护栏采用 Guardrails for Amazon Bedrock 功能中的 ApplyGuardrail API 。ApplyGuardrail 可以评估所有基础模型(FMs)的输入提示和模型响应,包括 Amazon Bedrock 上的 FMs、自定义 FMs 和第三方 FMs。通过实施这一功能, 组织可以在所有生成式 AI 应用程序中实现集中化的治理,从而增强内容审核的控制力和一致性。 + llm: Content Moderation Guardrails utilizes the ApplyGuardrail API, a feature of Guardrails for Amazon Bedrock. This API is capable of evaluating input prompts and model responses for all Foundation Models (FMs), including those on Amazon Bedrock, custom FMs, and third-party FMs. By implementing this functionality, organizations can achieve centralized governance across all their generative AI applications, thereby enhancing control and consistency in content moderation. +parameters: + - name: guardrail_id + type: string + required: true + label: + en_US: Guardrail ID + zh_Hans: Guardrail ID + human_description: + en_US: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'. + zh_Hans: 请输入已经在 Amazon Bedrock 上创建好的 Guardrail ID, 例如 'qk5nk0e4b77b'. + llm_description: Please enter the ID of the Guardrail that has already been created on Amazon Bedrock, for example 'qk5nk0e4b77b'. + form: form + - name: guardrail_version + type: string + required: true + label: + en_US: Guardrail Version Number + zh_Hans: Guardrail 版本号码 + human_description: + en_US: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2. + zh_Hans: 请输入已经在Amazon Bedrock 上创建好的Guardrail ID发布的版本, 通常使用版本号, 例如2. + llm_description: Please enter the published version of the Guardrail ID that has already been created on Amazon Bedrock. This is typically a version number, such as 2. + form: form + - name: source + type: string + required: true + label: + en_US: Content Source (INPUT or OUTPUT) + zh_Hans: 内容来源 (INPUT or OUTPUT) + human_description: + en_US: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT" + zh_Hans: 用于应用护栏的请求中所使用的数据来源。有效值为 "INPUT | OUTPUT" + llm_description: The source of data used in the request to apply the guardrail. Valid Values "INPUT | OUTPUT" + form: form + - name: text + type: string + required: true + label: + en_US: Content to be reviewed + zh_Hans: 待审查内容 + human_description: + en_US: The content used for requesting guardrail review, which can be either user input or LLM output. + zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。 + llm_description: The content used for requesting guardrail review, which can be either user input or LLM output. 
+ form: llm \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py new file mode 100644 index 00000000000000..cefbec17963760 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -0,0 +1,87 @@ +import boto3 +import json + +from typing import Any, Optional, Union, List +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class LambdaTranslateUtilsTool(BuiltinTool): + lambda_client: Any = None + + def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): + msg = { + "src_content":text_content, + "src_lang": src_lang, + "dest_lang":dest_lang, + "dictionary_id": dictionary_name, + "request_type" : request_type, + "model_id" : model_id + } + + invoke_response = self.lambda_client.invoke(FunctionName=lambda_name, + InvocationType='RequestResponse', + Payload=json.dumps(msg)) + response_body = invoke_response['Payload'] + + response_str = response_body.read().decode("unicode_escape") + + return response_str + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + line = 0 + try: + if not self.lambda_client: + aws_region = tool_parameters.get('aws_region', None) + if aws_region: + self.lambda_client = boto3.client("lambda", region_name=aws_region) + else: + self.lambda_client = boto3.client("lambda") + + line = 1 + text_content = tool_parameters.get('text_content', '') + if not text_content: + return self.create_text_message('Please input text_content') + + line = 2 + src_lang = tool_parameters.get('src_lang', '') + if not src_lang: + return self.create_text_message('Please input src_lang') + + line = 3 + dest_lang = tool_parameters.get('dest_lang', '') + if not dest_lang: + return self.create_text_message('Please input dest_lang') + + line = 4 + lambda_name = tool_parameters.get('lambda_name', '') + if not lambda_name: + return self.create_text_message('Please input lambda_name') + + line = 5 + request_type = tool_parameters.get('request_type', '') + if not request_type: + return self.create_text_message('Please input request_type') + + line = 6 + model_id = tool_parameters.get('model_id', '') + if not model_id: + return self.create_text_message('Please input model_id') + + line = 7 + dictionary_name = tool_parameters.get('dictionary_name', '') + if not dictionary_name: + return self.create_text_message('Please input dictionary_name') + + result = self._invoke_lambda(text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name) + + return self.create_text_message(text=result) + + except Exception as e: + return self.create_text_message(f'Exception {str(e)}, line : {line}') diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml new file mode 100644 index 00000000000000..ac97acdca9d8e6 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -0,0 +1,134 @@ +identity: + name: lambda_translate_utils + author: AWS + label: + en_US: TranslateTool + zh_Hans: 翻译工具 + pt_BR: TranslateTool + icon: icon.svg +description: + human: + en_US: A util tools for LLM translation, extra deployment is needed on AWS. 
Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag + pt_BR: A util tools for LLM translation, specfic Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + llm: A util tools for translation. +parameters: + - name: text_content + type: string + required: true + label: + en_US: source content for translation + zh_Hans: 待翻译原文 + pt_BR: source content for translation + human_description: + en_US: source content for translation + zh_Hans: 待翻译原文 + pt_BR: source content for translation + llm_description: source content for translation + form: llm + - name: src_lang + type: string + required: true + label: + en_US: source language code + zh_Hans: 原文语言代号 + pt_BR: source language code + human_description: + en_US: source language code + zh_Hans: 原文语言代号 + pt_BR: source language code + llm_description: source language code + form: llm + - name: dest_lang + type: string + required: true + label: + en_US: target language code + zh_Hans: 目标语言代号 + pt_BR: target language code + human_description: + en_US: target language code + zh_Hans: 目标语言代号 + pt_BR: target language code + llm_description: target language code + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of Lambda + zh_Hans: Lambda 所在的region + pt_BR: region of Lambda + human_description: + en_US: region of Lambda + zh_Hans: Lambda 所在的region + pt_BR: region of Lambda + llm_description: region of Lambda + form: form + - name: model_id + type: string + required: false + default: anthropic.claude-3-sonnet-20240229-v1:0 + label: + en_US: LLM model_id in bedrock + zh_Hans: bedrock上的大语言模型model_id + pt_BR: LLM model_id in bedrock + human_description: + en_US: LLM model_id in bedrock + zh_Hans: bedrock上的大语言模型model_id + pt_BR: LLM model_id in bedrock + llm_description: LLM model_id in bedrock + form: form + - name: dictionary_name + type: string + required: false + label: + en_US: dictionary name for term mapping + zh_Hans: 专词映射表名称 + pt_BR: dictionary name for term mapping + human_description: + en_US: dictionary name for term mapping + zh_Hans: 专词映射表名称 + pt_BR: dictionary name for term mapping + llm_description: dictionary name for term mapping + form: form + - name: request_type + type: select + required: false + label: + en_US: request type + zh_Hans: 请求类型 + pt_BR: request type + human_description: + en_US: request type + zh_Hans: 请求类型 + pt_BR: request type + default: term_mapping + options: + - value: term_mapping + label: + en_US: term_mapping + zh_Hans: 专词映射 + - value: segment_only + label: + en_US: segment_only + zh_Hans: 仅切词 + - value: translate + label: + en_US: translate + zh_Hans: 翻译内容 + form: form + - name: lambda_name + type: string + default: "translate_tool" + required: true + label: + en_US: AWS Lambda for term mapping retrieval + zh_Hans: 专词召回映射 - AWS Lambda + pt_BR: lambda name for term mapping retrieval + human_description: + en_US: AWS Lambda for term mapping retrieval + zh_Hans: 专词召回映射 - AWS Lambda + pt_BR: AWS Lambda for term mapping retrieval + llm_description: AWS Lambda for term mapping retrieval + form: form \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py new file mode 100644 index 00000000000000..b54098ab6daa9c --- /dev/null +++ 
b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -0,0 +1,85 @@ +import boto3 +import json + +from typing import Any, Optional, Union, List +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + + +class SageMakerReRankTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint:str = None + topk:int = None + + def _sagemaker_rerank(self, query_input: str, docs: List[str], rerank_endpoint:str): + inputs = [query_input]*len(docs) + response_model = self.sagemaker_client.invoke_endpoint( + EndpointName=rerank_endpoint, + Body=json.dumps( + { + "inputs": inputs, + "docs": docs + } + ), + ContentType="application/json", + ) + json_str = response_model['Body'].read().decode('utf8') + json_obj = json.loads(json_str) + scores = json_obj['scores'] + return scores if isinstance(scores, list) else [scores] + + def _invoke(self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + line = 0 + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get('aws_region', None) + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + line = 1 + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint', None) + + line = 2 + if not self.topk: + self.topk = tool_parameters.get('topk', 5) + + line = 3 + query = tool_parameters.get('query', '') + if not query: + return self.create_text_message('Please input query') + + line = 4 + candidate_texts = tool_parameters.get('candidate_texts', None) + if not candidate_texts: + return self.create_text_message('Please input candidate_texts') + + line = 5 + candidate_docs = json.loads(candidate_texts) + docs = [ item.get('content', None) for item in candidate_docs ] + + line = 6 + scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) + + line = 7 + for idx in range(len(candidate_docs)): + candidate_docs[idx]["score"] = scores[idx] + + line = 8 + sorted_candidate_docs = sorted(candidate_docs, key=lambda x: x['score'], reverse=True) + + line = 9 + results_str = json.dumps(sorted_candidate_docs[:self.topk], ensure_ascii=False) + return self.create_text_message(text=results_str) + + except Exception as e: + return self.create_text_message(f'Exception {str(e)}, line : {line}') + \ No newline at end of file diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml new file mode 100644 index 00000000000000..e527adf818ff79 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml @@ -0,0 +1,82 @@ +identity: + name: sagemaker_text_rerank + author: AWS + label: + en_US: SagemakerRerank + zh_Hans: Sagemaker重排序 + pt_BR: SagemakerRerank + icon: icon.svg +description: + human: + en_US: A tool for performing text similarity ranking. You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool + zh_Hans: Sagemaker重排序工具, 请参考 Github Repo - https://github.com/aws-samples/dify-aws-tool上的部署脚本 + pt_BR: A tool for performing text similarity ranking. + llm: A tool for performing text similarity ranking. 
You can find deploy notebook on Github Repo - https://github.com/aws-samples/dify-aws-tool +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for reranking + zh_Hans: 重排序的SageMaker 端点 + pt_BR: sagemaker endpoint for reranking + human_description: + en_US: sagemaker endpoint for reranking + zh_Hans: 重排序的SageMaker 端点 + pt_BR: sagemaker endpoint for reranking + llm_description: sagemaker endpoint for reranking + form: form + - name: query + type: string + required: true + label: + en_US: Query string + zh_Hans: 查询语句 + pt_BR: Query string + human_description: + en_US: key words for searching + zh_Hans: 查询关键词 + pt_BR: key words for searching + llm_description: key words for searching + form: llm + - name: candidate_texts + type: string + required: true + label: + en_US: text candidates + zh_Hans: 候选文本 + pt_BR: text candidates + human_description: + en_US: searched candidates by query + zh_Hans: 查询文本搜到候选文本 + pt_BR: searched candidates by query + llm_description: searched candidates by query + form: llm + - name: topk + type: number + required: false + form: form + label: + en_US: Limit for results count + zh_Hans: 返回个数限制 + pt_BR: Limit for results count + human_description: + en_US: Limit for results count + zh_Hans: 返回个数限制 + pt_BR: Limit for results count + min: 1 + max: 10 + default: 5 + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + pt_BR: region of sagemaker endpoint + llm_description: region of sagemaker endpoint + form: form \ No newline at end of file From f634ca3a465232fe9472490baab3a5e1a2ccffa2 Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Fri, 26 Jul 2024 19:17:53 +0800 Subject: [PATCH 07/10] fix lint error --- .../provider/builtin/aws/tools/apply_guardrail.py | 6 +++--- .../builtin/aws/tools/lambda_translate_utils.py | 4 ++-- .../builtin/aws/tools/sagemaker_text_rerank.py | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 658a9753bb5e8f..1bee041588e5e3 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -1,7 +1,7 @@ import boto3 import json import logging -from typing import Any, Dict, Union, List +from typing import Any, Union from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage @@ -20,8 +20,8 @@ class GuardrailParameters(BaseModel): class ApplyGuardrailTool(BuiltinTool): def _invoke(self, user_id: str, - tool_parameters: Dict[str, Any] - ) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]: + tool_parameters: dict[str, Any] + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: """ Invoke the ApplyGuardrail tool """ diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index cefbec17963760..280a6909604af2 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -1,7 +1,7 @@ import boto3 import json -from typing import Any, Optional, Union, List +from typing import Any, Union from core.tools.entities.tool_entities 
import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -38,7 +38,7 @@ def _invoke(self, line = 0 try: if not self.lambda_client: - aws_region = tool_parameters.get('aws_region', None) + aws_region = tool_parameters.get('aws_region') if aws_region: self.lambda_client = boto3.client("lambda", region_name=aws_region) else: diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index b54098ab6daa9c..4c115767ebfa88 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -1,7 +1,7 @@ import boto3 import json -from typing import Any, Optional, Union, List +from typing import Any, Union from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool @@ -11,7 +11,7 @@ class SageMakerReRankTool(BuiltinTool): sagemaker_endpoint:str = None topk:int = None - def _sagemaker_rerank(self, query_input: str, docs: List[str], rerank_endpoint:str): + def _sagemaker_rerank(self, query_input: str, docs: list[str], rerank_endpoint:str): inputs = [query_input]*len(docs) response_model = self.sagemaker_client.invoke_endpoint( EndpointName=rerank_endpoint, @@ -38,7 +38,7 @@ def _invoke(self, line = 0 try: if not self.sagemaker_client: - aws_region = tool_parameters.get('aws_region', None) + aws_region = tool_parameters.get('aws_region') if aws_region: self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) else: @@ -46,7 +46,7 @@ def _invoke(self, line = 1 if not self.sagemaker_endpoint: - self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint', None) + self.sagemaker_endpoint = tool_parameters.get('sagemaker_endpoint') line = 2 if not self.topk: @@ -58,13 +58,13 @@ def _invoke(self, return self.create_text_message('Please input query') line = 4 - candidate_texts = tool_parameters.get('candidate_texts', None) + candidate_texts = tool_parameters.get('candidate_texts') if not candidate_texts: return self.create_text_message('Please input candidate_texts') line = 5 candidate_docs = json.loads(candidate_texts) - docs = [ item.get('content', None) for item in candidate_docs ] + docs = [ item.get('content') for item in candidate_docs ] line = 6 scores = self._sagemaker_rerank(query_input=query, docs=docs, rerank_endpoint=self.sagemaker_endpoint) From 8d48ae3d579d104edacb4df4d5262c142110367b Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Fri, 26 Jul 2024 19:22:39 +0800 Subject: [PATCH 08/10] fix lint error2 --- api/core/tools/provider/builtin/aws/tools/apply_guardrail.py | 3 ++- .../provider/builtin/aws/tools/lambda_translate_utils.py | 5 +++-- .../provider/builtin/aws/tools/sagemaker_text_rerank.py | 5 +++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index 1bee041588e5e3..ee2dce0e031a9d 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -1,7 +1,8 @@ -import boto3 import json import logging from typing import Any, Union + +import boto3 from pydantic import BaseModel, Field from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py 
index 280a6909604af2..005ba3deb53311 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -1,7 +1,8 @@ -import boto3 import json - from typing import Any, Union + +import boto3 + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py index 4c115767ebfa88..d4bc446e5b13d8 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.py @@ -1,7 +1,8 @@ -import boto3 import json - from typing import Any, Union + +import boto3 + from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool From 50b7e4d00ff043613e48b5e059b63111a6396ab9 Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Fri, 26 Jul 2024 19:28:41 +0800 Subject: [PATCH 09/10] Fix supper lint error --- api/core/tools/provider/builtin/aws/aws.yaml | 2 +- api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml | 2 +- .../provider/builtin/aws/tools/lambda_translate_utils.yaml | 2 +- .../tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/aws.yaml b/api/core/tools/provider/builtin/aws/aws.yaml index 203214abf81869..847c6824a53df6 100644 --- a/api/core/tools/provider/builtin/aws/aws.yaml +++ b/api/core/tools/provider/builtin/aws/aws.yaml @@ -12,4 +12,4 @@ identity: icon: icon.svg tags: - search -credentials_for_provider: \ No newline at end of file +credentials_for_provider: diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml index 83e2ea05afd8de..2b7c8abb442f77 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.yaml @@ -53,4 +53,4 @@ parameters: en_US: The content used for requesting guardrail review, which can be either user input or LLM output. zh_Hans: 用于请求护栏审查的内容,可以是用户输入或 LLM 输出。 llm_description: The content used for requesting guardrail review, which can be either user input or LLM output. 
- form: llm \ No newline at end of file + form: llm diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml index ac97acdca9d8e6..a35c9f49fb9720 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -131,4 +131,4 @@ parameters: zh_Hans: 专词召回映射 - AWS Lambda pt_BR: AWS Lambda for term mapping retrieval llm_description: AWS Lambda for term mapping retrieval - form: form \ No newline at end of file + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml index e527adf818ff79..d1dfdb9f84a858 100644 --- a/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_text_rerank.yaml @@ -79,4 +79,4 @@ parameters: zh_Hans: SageMaker 端点所在的region pt_BR: region of sagemaker endpoint llm_description: region of sagemaker endpoint - form: form \ No newline at end of file + form: form From 15b14f13f3d462da55219ec4cb08ff6050f4a275 Mon Sep 17 00:00:00 2001 From: Yuanbo Li Date: Wed, 31 Jul 2024 12:47:10 +0800 Subject: [PATCH 10/10] remove useless log --- api/core/tools/provider/builtin/aws/tools/apply_guardrail.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py index ee2dce0e031a9d..9c006733bdd95d 100644 --- a/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py +++ b/api/core/tools/provider/builtin/aws/tools/apply_guardrail.py @@ -40,8 +40,6 @@ def _invoke(self, source=params.source, content=[{"text": {"text": params.text}}] ) - - logger.info(f"Raw response from AWS: {json.dumps(response, indent=2)}") # Check for empty response if not response:
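
The new sagemaker_text_rerank tool implies a specific contract: candidate_texts must be a JSON array of objects carrying a "content" field, and the SageMaker endpoint is expected to accept a body of the form {"inputs": [...], "docs": [...]} and return {"scores": [...]}. The snippet below is only a sketch of that contract, derived from _sagemaker_rerank and the candidate_texts parsing in _invoke; the region, endpoint name, and sample documents are placeholders rather than values from this patch series.

# Sketch only -- mirrors the payload built in _sagemaker_rerank and the
# candidate_texts parsing in _invoke. Region, endpoint name, and texts are placeholders.
import json

import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

query = "misaka mikoto"
# candidate_texts as the tool expects it: a JSON array of {"content": ...} objects
candidate_texts = json.dumps([
    {"content": "hello"},
    {"content": "hello world"},
])

docs = [item.get("content") for item in json.loads(candidate_texts)]
response = runtime.invoke_endpoint(
    EndpointName="my-rerank-endpoint",  # placeholder endpoint name
    Body=json.dumps({"inputs": [query] * len(docs), "docs": docs}),
    ContentType="application/json",
)
scores = json.loads(response["Body"].read().decode("utf8"))["scores"]
print(scores)  # one relevance score per doc, used by the tool to sort the top-k results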
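
Likewise, lambda_translate_utils assumes a Lambda function (deployed separately, per the repository linked in its YAML) that accepts the JSON event assembled in _invoke_lambda. A minimal sketch of that event and invocation follows; every value here is a placeholder, and the function name must match whatever name the Lambda was actually deployed under.

# Sketch only -- the event shape is copied from _invoke_lambda; all values are
# placeholders, and "translate_tool" is just the default lambda_name from the YAML.
import json

import boto3

lambda_client = boto3.client("lambda", region_name="us-east-1")

event = {
    "src_content": "Hello world",
    "src_lang": "en",
    "dest_lang": "zh",
    "dictionary_id": "my-dictionary",  # placeholder dictionary name
    "request_type": "term_mapping",    # or "segment_only" / "translate"
    "model_id": "anthropic.claude-3-sonnet-20240229-v1:0",
}

resp = lambda_client.invoke(
    FunctionName="translate_tool",     # placeholder; use your deployed function's name
    InvocationType="RequestResponse",
    Payload=json.dumps(event),
)
print(resp["Payload"].read().decode("unicode_escape"))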