diff --git a/.github/workflows/leaderboard_build.yml b/.github/workflows/leaderboard_build.yml index 90d52c70a2..2aa3b9089a 100644 --- a/.github/workflows/leaderboard_build.yml +++ b/.github/workflows/leaderboard_build.yml @@ -22,7 +22,7 @@ jobs: - name: Install dependencies (incl. leaderboard extra) run: | - pip install ".[dev,leaderboard]" + pip install ".[leaderboard]" --group dev - name: Run leaderboard build test run: | diff --git a/Makefile b/Makefile index 9428427158..11e0c85da1 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ install: @echo "--- 🚀 Installing project dependencies ---" - pip install -e ".[dev,image]" + pip install -e ".[image]" --group dev pre-commit install install-for-tests: @echo "--- 🚀 Installing project dependencies for test ---" @echo "This ensures that the project is not installed in editable mode" - pip install ".[dev,image]" + pip install ".[image]" --group dev lint: @echo "--- 🧹 Running linters ---" @@ -17,7 +17,7 @@ lint-check: @echo "--- 🧹 Check is project is linted ---" # Required for CI to work, otherwise it will just pass ruff format . --check # running ruff formatting - ruff check **/*.py # running ruff linting + ruff check . # running ruff linting test: @echo "--- 🧪 Running tests ---" @@ -43,7 +43,7 @@ build-docs: model-load-test: @echo "--- 🚀 Running model load test ---" - pip install ".[dev, pylate,gritlm,xformers,model2vec]" + pip install ".[pylate,gritlm,xformers,model2vec]" --group dev python scripts/extract_model_names.py $(BASE_BRANCH) --return_one_model_name_per_file python tests/test_models/model_loading.py --model_name_file scripts/model_names.txt diff --git a/docs/benchmarks.md b/docs/benchmarks.md index 032113c401..202fde6825 100644 --- a/docs/benchmarks.md +++ b/docs/benchmarks.md @@ -30,6 +30,7 @@ The following table gives you an overview of the benchmarks in MTEB. 
| [MTEB(Indic, v1)](https://arxiv.org/abs/2502.13595) | Indic | 23 | BitextMining: 4, Clustering: 1, Classification: 13, PairClassification: 1, Retrieval: 2, Reranking: 1, STS: 1 | [Constructed, Encyclopaedic, Fiction, Government, Legal, News, Non-fiction, Religious, Reviews, Social, Spoken, Web, Written] | asm,awa,ben,bgc,bho,bod,boy,brx,doi,eng,gbm,gom,guj,hin,hne,kan,kas,mai,mal,mar,mni,mup,mwr,nep,npi,ory,pan,pus,raj,san,sat,snd,tam,tel,urd | | MTEB(Law, v1) | Legal | 8 | Retrieval: 8 | [Legal, Written] | deu,eng,zho | | MTEB(Medical, v1) | Medical | 12 | Retrieval: 9, Clustering: 2, Reranking: 1 | [Academic, Government, Medical, Non-fiction, Web, Written] | ara,cmn,eng,fra,kor,pol,rus,spa,vie,zho | +| [MTEB(Multilingual, v1)](https://arxiv.org/abs/2502.13595) | Multilingual | 132 | BitextMining: 13, Classification: 43, Clustering: 17, Retrieval: 18, InstructionRetrieval: 3, MultilabelClassification: 5, PairClassification: 11, Reranking: 6, STS: 16 | [Academic, Blog, Constructed, Encyclopaedic, Entertainment, Fiction, Financial, Government, Legal, Medical, News, Non-fiction, Programming, Religious, Reviews, Social, Spoken, Subtitles, Web, Written] | aai,aak,aau,aaz,abs,abt,abx,aby,ace,acf,acm,acq,acr,acu,adz,aeb,aer,aey,afr,agd,agg,agm,agn,agr,agt,agu,aia,aii,ajp,aka,ake,alp,alq,als,aly,ame,amf,amh,amk,amm,amn,amo,amp,amr,amu,amx,ang,anh,anv,aoi,aoj,aom,aon,apb,apc,ape,apn,apr,apu,apw,apz,ara,arb,are,arl,arn,arp,arq,ars,ary,arz,asm,aso,ast,ata,atb,atd,atg,att,auc,aui,auy,avt,awa,awb,awk,awx,ayr,azb,aze,azg,azj,azz,bak,bam,ban,bao,bba,bbb,bbc,bbr,bch,bco,bdd,bea,bef,bel,bem,ben,beo,ber,beu,bew,bgc,bgs,bgt,bhg,bhl,bho,bhp,big,bjk,bjn,bjp,bjr,bjv,bjz,bkd,bki,bkq,bkx,blw,blz,bmh,bmk,bmr,bmu,bnp,boa,bod,boj,bon,bos,box,boy,bpr,bps,bqc,bqp,bre,brx,bsj,bsn,bsp,bss,bug,buk,bul,bus,bvd,bvr,bxh,byr,byx,bzd,bzh,bzj,caa,cab,cac,caf,cak,cao,cap,car,cat,cav,cax,cbc,cbi,cbk,cbr,cbs,cbt,cbu,cbv,cco,ceb,cek,ces,cgc,cha,chd,chf,chk,chq,chv,chz,cjk,cjo,cjv,ckb,cle,clu,cme,cmn,cmo,cni,cnl,cnt,cof,con,cop,cor,cot,cpa,cpb,cpc,cpu,cpy,crh,crn,crx,csb,cso,csy,cta,cth,ctp,ctu,cub,cuc,cui,cuk,cut,cux,cwe,cya,cym,daa,dad,dah,dan,ded,deu,dgc,dgr,dgz,dhg,dif,dik,div,dji,djk,djr,dob,doi,dop,dov,dsb,dtp,dwr,dww,dwy,dyu,dzo,ebk,eko,ell,emi,emp,eng,enq,epo,eri,ese,esk,est,etr,eus,ewe,faa,fai,fao,far,fas,ffm,fij,fil,fin,fon,for,fra,fry,fuc,fue,fuf,fuh,fur,fuv,gah,gai,gam,gaw,gaz,gbm,gdn,gdr,geb,gfk,ghs,gla,gle,glg,glk,glv,gmv,gng,gnn,gnw,gof,gom,grc,grn,gsw,gub,guh,gui,guj,gul,gum,gun,guo,gup,gux,gvc,gvf,gvn,gvs,gwi,gym,gyr,hat,hau,haw,hbo,hch,heb,heg,hin,hix,hla,hlt,hmn,hmo,hne,hns,hop,hot,hrv,hsb,hto,hub,hui,hun,hus,huu,huv,hvn,hye,ian,ibo,ido,ign,ikk,ikw,ile,ilo,imo,ina,inb,ind,ino,iou,ipi,isl,isn,ita,iws,ixl,jac,jae,jao,jav,jic,jid,jiv,jni,jpn,jvn,kab,kac,kam,kan,kaq,kas,kat,kaz,kbc,kbh,kbm,kbp,kbq,kdc,kde,kdl,kea,kek,ken,kew,kgf,kgk,kgp,khk,khm,khs,khz,kik,kin,kir,kiw,kiz,kje,kjs,kkc,kkl,klt,klv,kmb,kmg,kmh,kmk,kmo,kmr,kms,kmu,knc,kne,knf,knj,knv,kon,kor,kos,kpf,kpg,kpj,kpr,kpw,kpx,kqa,kqc,kqf,kql,kqw,krc,ksd,ksj,ksr,ktm,kto,kud,kue,kup,kur,kvg,kvn,kwd,kwf,kwi,kwj,kyc,kyf,kyg,kyq,kyz,kze,kzj,lac,lao,lat,lav,lbb,lbk,lcm,leu,lex,lfn,lgl,lid,lif,lij,lim,lin,lit,llg,lmo,ltg,ltz,lua,lug,luo,lus,lvs,lww,maa,mad,mag,mai,maj,mak,mal,mam,maq,mar,mau,mav,max,maz,mbb,mbc,mbh,mbj,mbl,mbs,mbt,mca,mcb,mcd,mcf,mco,mcp,mcq,mcr,mdy,med,mee,mek,meq,met,meu,mey,mgc,mgh,mgw,mhl,mhr,mib,mic,mie,mig,mih,mil,min,mio,mir,mit,miz,mjc,mkd,mkj,mkl,mkn,mks,mle,mlg,mlh,mlp,mlt,mmo,mmx,mna,mni,mon,mop,mos,mox,mph,mpj,mpm,mpp,mps,mpt,mpx,mqb,mqj
,mri,msa,msb,msc,msk,msm,msy,mti,mto,mui,mup,mux,muy,mva,mvn,mwc,mwe,mwf,mwp,mwr,mxb,mxp,mxq,mxt,mya,myk,myu,myw,myy,mzz,nab,naf,nak,nas,nbq,nca,nch,ncj,ncl,ncu,nde,ndg,ndj,nds,nep,nfa,ngp,ngu,nhe,nhg,nhi,nho,nhr,nhu,nhw,nhy,nif,nii,nij,nin,nko,nld,nlg,nna,nno,nnq,noa,nob,nop,nor,not,nou,nov,npi,npl,nqo,nsn,nso,nss,ntj,ntp,ntu,nus,nuy,nvm,nwi,nya,nys,nyu,obo,oci,okv,omw,ong,ons,ood,opm,orm,orv,ory,ote,otm,otn,otq,ots,pab,pad,pag,pah,pam,pan,pao,pap,pbt,pcm,pes,pib,pio,pir,piu,pjt,pls,plt,plu,pma,pms,poe,poh,poi,pol,pon,por,poy,ppo,prf,pri,prs,ptp,ptu,pus,pwg,qub,quc,quf,quh,qul,qup,quy,qvc,qve,qvh,qvm,qvn,qvs,qvw,qvz,qwh,qxh,qxn,qxo,rai,raj,reg,rej,rgu,rkb,rmc,rmy,rom,ron,roo,rop,row,rro,ruf,rug,run,rus,rwo,sab,sag,sah,san,sat,sbe,sbk,sbs,scn,sco,seh,sey,sgb,sgz,shi,shj,shn,shp,sim,sin,sja,slk,sll,slv,smk,smo,sna,snc,snd,snn,snp,snx,sny,som,soq,sot,soy,spa,spl,spm,spp,sps,spy,sqi,srd,sri,srm,srn,srp,srq,ssd,ssg,ssw,ssx,stp,sua,sue,sun,sus,suz,svk,swa,swe,swg,swh,swp,sxb,szl,tac,tah,taj,tam,taq,tat,tav,taw,tbc,tbf,tbg,tbo,tbz,tca,tcs,tcz,tdt,tee,tel,ter,tet,tew,tfr,tgk,tgl,tgo,tgp,tha,tif,tim,tir,tiw,tiy,tke,tku,tlf,tmd,tna,tnc,tnk,tnn,tnp,toc,tod,tof,toj,ton,too,top,tos,tpa,tpi,tpt,tpz,trc,tsn,tso,tsw,ttc,tte,tuc,tue,tuf,tuk,tum,tuo,tur,tvk,twi,txq,txu,tyv,tzj,tzl,tzm,tzo,ubr,ubu,udu,uig,ukr,uli,ulk,umb,upv,ura,urb,urd,uri,urt,urw,usa,usp,uvh,uvl,uzb,uzn,vec,ven,vid,vie,viv,vmy,waj,wal,wap,war,wat,wbi,wbp,wed,wer,wim,wiu,wiv,wln,wmt,wmw,wnc,wnu,wol,wos,wrk,wro,wrs,wsk,wuu,wuv,xav,xbi,xed,xho,xla,xnn,xon,xsi,xtd,xtm,yaa,yad,yal,yap,yaq,yby,ycn,ydd,yid,yka,yle,yml,yon,yor,yrb,yre,yss,yue,yuj,yut,yuw,yva,zaa,zab,zac,zad,zai,zaj,zam,zao,zap,zar,zas,zat,zav,zaw,zca,zga,zho,zia,ziw,zlm,zos,zpc,zpl,zpm,zpo,zpq,zpu,zpv,zpz,zsm,zsr,ztq,zty,zul,zyp | | [MTEB(Multilingual, v2)](https://arxiv.org/abs/2502.13595) | Multilingual | 131 | BitextMining: 13, Classification: 43, Clustering: 16, Retrieval: 18, InstructionRetrieval: 3, MultilabelClassification: 5, PairClassification: 11, Reranking: 6, STS: 16 | [Academic, Blog, Constructed, Encyclopaedic, Entertainment, Fiction, Financial, Government, Legal, Medical, News, Non-fiction, Programming, Religious, Reviews, Social, Spoken, Subtitles, Web, Written] | 
aai,aak,aau,aaz,abs,abt,abx,aby,ace,acf,acm,acq,acr,acu,adz,aeb,aer,aey,afr,agd,agg,agm,agn,agr,agt,agu,aia,aii,ajp,aka,ake,alp,alq,als,aly,ame,amf,amh,amk,amm,amn,amo,amp,amr,amu,amx,ang,anh,anv,aoi,aoj,aom,aon,apb,apc,ape,apn,apr,apu,apw,apz,ara,arb,are,arl,arn,arp,arq,ars,ary,arz,asm,aso,ast,ata,atb,atd,atg,att,auc,aui,auy,avt,awa,awb,awk,awx,ayr,azb,aze,azg,azj,azz,bak,bam,ban,bao,bba,bbb,bbc,bbr,bch,bco,bdd,bea,bef,bel,bem,ben,beo,ber,beu,bew,bgc,bgs,bgt,bhg,bhl,bho,bhp,big,bjk,bjn,bjp,bjr,bjv,bjz,bkd,bki,bkq,bkx,blw,blz,bmh,bmk,bmr,bmu,bnp,boa,bod,boj,bon,bos,box,boy,bpr,bps,bqc,bqp,bre,brx,bsj,bsn,bsp,bss,bug,buk,bul,bus,bvd,bvr,bxh,byr,byx,bzd,bzh,bzj,caa,cab,cac,caf,cak,cao,cap,car,cat,cav,cax,cbc,cbi,cbk,cbr,cbs,cbt,cbu,cbv,cco,ceb,cek,ces,cgc,cha,chd,chf,chk,chq,chv,chz,cjk,cjo,cjv,ckb,cle,clu,cme,cmn,cmo,cni,cnl,cnt,cof,con,cop,cor,cot,cpa,cpb,cpc,cpu,cpy,crh,crn,crx,csb,cso,csy,cta,cth,ctp,ctu,cub,cuc,cui,cuk,cut,cux,cwe,cya,cym,daa,dad,dah,dan,ded,deu,dgc,dgr,dgz,dhg,dif,dik,div,dji,djk,djr,dob,doi,dop,dov,dsb,dtp,dwr,dww,dwy,dyu,dzo,ebk,eko,ell,emi,emp,eng,enq,epo,eri,ese,esk,est,etr,eus,ewe,faa,fai,fao,far,fas,ffm,fij,fil,fin,fon,for,fra,fry,fuc,fue,fuf,fuh,fur,fuv,gah,gai,gam,gaw,gaz,gbm,gdn,gdr,geb,gfk,ghs,gla,gle,glg,glk,glv,gmv,gng,gnn,gnw,gof,gom,grc,grn,gsw,gub,guh,gui,guj,gul,gum,gun,guo,gup,gux,gvc,gvf,gvn,gvs,gwi,gym,gyr,hat,hau,haw,hbo,hch,heb,heg,hin,hix,hla,hlt,hmn,hmo,hne,hns,hop,hot,hrv,hsb,hto,hub,hui,hun,hus,huu,huv,hvn,hye,ian,ibo,ido,ign,ikk,ikw,ile,ilo,imo,ina,inb,ind,ino,iou,ipi,isl,isn,ita,iws,ixl,jac,jae,jao,jav,jic,jid,jiv,jni,jpn,jvn,kab,kac,kam,kan,kaq,kas,kat,kaz,kbc,kbh,kbm,kbp,kbq,kdc,kde,kdl,kea,kek,ken,kew,kgf,kgk,kgp,khk,khm,khs,khz,kik,kin,kir,kiw,kiz,kje,kjs,kkc,kkl,klt,klv,kmb,kmg,kmh,kmk,kmo,kmr,kms,kmu,knc,kne,knf,knj,knv,kon,kor,kos,kpf,kpg,kpj,kpr,kpw,kpx,kqa,kqc,kqf,kql,kqw,krc,ksd,ksj,ksr,ktm,kto,kud,kue,kup,kur,kvg,kvn,kwd,kwf,kwi,kwj,kyc,kyf,kyg,kyq,kyz,kze,kzj,lac,lao,lat,lav,lbb,lbk,lcm,leu,lex,lfn,lgl,lid,lif,lij,lim,lin,lit,llg,lmo,ltg,ltz,lua,lug,luo,lus,lvs,lww,maa,mad,mag,mai,maj,mak,mal,mam,maq,mar,mau,mav,max,maz,mbb,mbc,mbh,mbj,mbl,mbs,mbt,mca,mcb,mcd,mcf,mco,mcp,mcq,mcr,mdy,med,mee,mek,meq,met,meu,mey,mgc,mgh,mgw,mhl,mhr,mib,mic,mie,mig,mih,mil,min,mio,mir,mit,miz,mjc,mkd,mkj,mkl,mkn,mks,mle,mlg,mlh,mlp,mlt,mmo,mmx,mna,mni,mon,mop,mos,mox,mph,mpj,mpm,mpp,mps,mpt,mpx,mqb,mqj,mri,msa,msb,msc,msk,msm,msy,mti,mto,mui,mup,mux,muy,mva,mvn,mwc,mwe,mwf,mwp,mwr,mxb,mxp,mxq,mxt,mya,myk,myu,myw,myy,mzz,nab,naf,nak,nas,nbq,nca,nch,ncj,ncl,ncu,nde,ndg,ndj,nds,nep,nfa,ngp,ngu,nhe,nhg,nhi,nho,nhr,nhu,nhw,nhy,nif,nii,nij,nin,nko,nld,nlg,nna,nno,nnq,noa,nob,nop,nor,not,nou,nov,npi,npl,nqo,nsn,nso,nss,ntj,ntp,ntu,nus,nuy,nvm,nwi,nya,nys,nyu,obo,oci,okv,omw,ong,ons,ood,opm,orm,orv,ory,ote,otm,otn,otq,ots,pab,pad,pag,pah,pam,pan,pao,pap,pbt,pcm,pes,pib,pio,pir,piu,pjt,pls,plt,plu,pma,pms,poe,poh,poi,pol,pon,por,poy,ppo,prf,pri,prs,ptp,ptu,pus,pwg,qub,quc,quf,quh,qul,qup,quy,qvc,qve,qvh,qvm,qvn,qvs,qvw,qvz,qwh,qxh,qxn,qxo,rai,raj,reg,rej,rgu,rkb,rmc,rmy,rom,ron,roo,rop,row,rro,ruf,rug,run,rus,rwo,sab,sag,sah,san,sat,sbe,sbk,sbs,scn,sco,seh,sey,sgb,sgz,shi,shj,shn,shp,sim,sin,sja,slk,sll,slv,smk,smo,sna,snc,snd,snn,snp,snx,sny,som,soq,sot,soy,spa,spl,spm,spp,sps,spy,sqi,srd,sri,srm,srn,srp,srq,ssd,ssg,ssw,ssx,stp,sua,sue,sun,sus,suz,svk,swa,swe,swg,swh,swp,sxb,szl,tac,tah,taj,tam,taq,tat,tav,taw,tbc,tbf,tbg,tbo,tbz,tca,tcs,tcz,tdt,tee,tel,ter,tet,tew,tfr,tgk,tgl,tgo,tgp,tha,tif,tim,tir,tiw,tiy,tke,tku,tlf,tmd,tna,tnc,tnk,tnn,tnp,toc,tod,tof,toj
,ton,too,top,tos,tpa,tpi,tpt,tpz,trc,tsn,tso,tsw,ttc,tte,tuc,tue,tuf,tuk,tum,tuo,tur,tvk,twi,txq,txu,tyv,tzj,tzl,tzm,tzo,ubr,ubu,udu,uig,ukr,uli,ulk,umb,upv,ura,urb,urd,uri,urt,urw,usa,usp,uvh,uvl,uzb,uzn,vec,ven,vid,vie,viv,vmy,waj,wal,wap,war,wat,wbi,wbp,wed,wer,wim,wiu,wiv,wln,wmt,wmw,wnc,wnu,wol,wos,wrk,wro,wrs,wsk,wuu,wuv,xav,xbi,xed,xho,xla,xnn,xon,xsi,xtd,xtm,yaa,yad,yal,yap,yaq,yby,ycn,ydd,yid,yka,yle,yml,yon,yor,yrb,yre,yss,yue,yuj,yut,yuw,yva,zaa,zab,zac,zad,zai,zaj,zam,zao,zap,zar,zas,zat,zav,zaw,zca,zga,zho,zia,ziw,zlm,zos,zpc,zpl,zpm,zpo,zpq,zpu,zpv,zpz,zsm,zsr,ztq,zty,zul,zyp | | [MTEB(Scandinavian, v1)](https://kennethenevoldsen.github.io/scandinavian-embedding-benchmark/) | Scandinavian | 28 | BitextMining: 2, Classification: 13, Retrieval: 7, Clustering: 6 | [Blog, Encyclopaedic, Fiction, Government, Legal, News, Non-fiction, Reviews, Social, Spoken, Web, Written] | dan,fao,isl,nno,nob,swe | | [MTEB(cmn, v1)](https://github.com/FlagOpen/FlagEmbedding/tree/master/research/C_MTEB) | Chinese | 32 | Retrieval: 8, Reranking: 4, PairClassification: 2, Clustering: 4, STS: 7, Classification: 7 | [Academic, Entertainment, Financial, Government, Medical, Non-fiction, Written] | cmn | diff --git a/mteb/abstasks/TaskMetadata.py b/mteb/abstasks/TaskMetadata.py index aa4658f513..28b796c1f1 100644 --- a/mteb/abstasks/TaskMetadata.py +++ b/mteb/abstasks/TaskMetadata.py @@ -235,16 +235,15 @@ METRIC_VALUE = Union[int, float, dict[str, Any]] -class PromptDict(TypedDict, total=False): - """A dictionary containing the prompt used for the task. - - Args: - query: The prompt used for the queries in the task. - passage: The prompt used for the passages in the task. - """ +PromptDict = TypedDict( + "PromptDict", {prompt_type.value: str for prompt_type in PromptType}, total=False +) +"""A dictionary containing the prompt used for the task. - query: str - passage: str +Args: + query: The prompt used for the queries in the task. + document: The prompt used for the passages in the task. +""" class DescriptiveStatistics(TypedDict): @@ -253,9 +252,6 @@ class DescriptiveStatistics(TypedDict): pass -METRIC_VALUE = Union[int, float, dict[str, Any]] - - logger = logging.getLogger(__name__) @@ -289,6 +285,7 @@ class TaskMetadata(BaseModel): prompt: The prompt used for the task. Can be a string or a dictionary containing the query and passage prompts. bibtex_citation: The BibTeX citation for the dataset. Should be an empty string if no citation is available. adapted_from: Datasets adapted (translated, sampled from, etc.) from other datasets. + is_public: Whether the dataset is publicly available. If False (closed/private), a HuggingFace token is required to run the datasets. 
""" dataset: dict[str, Any] @@ -316,6 +313,7 @@ class TaskMetadata(BaseModel): sample_creation: SAMPLE_CREATION_METHOD | None = None bibtex_citation: str | None = None adapted_from: list[str] | None = None + is_public: bool = True def validate_metadata(self) -> None: self.dataset_path_is_specified(self.dataset) @@ -323,6 +321,7 @@ def validate_metadata(self) -> None: self.eval_langs_are_valid(self.eval_langs) @field_validator("dataset") + @classmethod def _check_dataset_path_is_specified( cls, dataset: dict[str, Any] ) -> dict[str, Any]: @@ -330,6 +329,7 @@ def _check_dataset_path_is_specified( return dataset @field_validator("dataset") + @classmethod def _check_dataset_revision_is_specified( cls, dataset: dict[str, Any] ) -> dict[str, Any]: @@ -337,6 +337,7 @@ def _check_dataset_revision_is_specified( return dataset @field_validator("prompt") + @classmethod def _check_prompt_is_valid( cls, prompt: str | PromptDict | None ) -> str | PromptDict | None: @@ -344,7 +345,7 @@ def _check_prompt_is_valid( for key in prompt: if key not in [e.value for e in PromptType]: raise ValueError( - "The prompt dictionary should only contain the keys 'query' and 'passage'." + "The prompt dictionary should only contain the keys 'query' and 'document'." ) return prompt @@ -419,7 +420,7 @@ def is_filled(self) -> bool: return all( getattr(self, field_name) is not None for field_name in self.model_fields - if field_name not in ["prompt", "adapted_from"] + if field_name not in ["prompt", "adapted_from", "is_public"] ) @property diff --git a/mteb/benchmarks/benchmarks/__init__.py b/mteb/benchmarks/benchmarks/__init__.py new file mode 100644 index 0000000000..f4b34233db --- /dev/null +++ b/mteb/benchmarks/benchmarks/__init__.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from mteb.benchmarks.benchmark import Benchmark +from mteb.benchmarks.benchmarks.benchmarks import ( + BEIR, + BEIR_NL, + BRIGHT, + BRIGHT_LONG, + BUILT_MTEB, + C_MTEB, + CHEMTEB, + CODE_RAG, + ENCODECHKA, + FA_MTEB, + JINA_VDR, + LONG_EMBED, + MIEB_ENG, + MIEB_IMG, + MIEB_LITE, + MIEB_MULTILINGUAL, + MTEB_DEU, + MTEB_EN, + MTEB_ENG_CLASSIC, + MTEB_EU, + MTEB_FRA, + MTEB_INDIC, + MTEB_JPN, + MTEB_KOR, + MTEB_MAIN_RU, + MTEB_MINERS_BITEXT_MINING, + MTEB_POL, + MTEB_RETRIEVAL_LAW, + MTEB_RETRIEVAL_MEDICAL, + MTEB_RETRIEVAL_WITH_INSTRUCTIONS, + NANOBEIR, + R2MED, + RU_SCI_BENCH, + SEB, + VIDORE, + VIDORE_V2, + VISUAL_DOCUMENT_RETRIEVAL, + VN_MTEB, + CoIR, + MTEB_code, + MTEB_multilingual_v1, + MTEB_multilingual_v2, + RAR_b, +) + +__all__ = [ + "Benchmark", + "MTEB_EN", + "MTEB_ENG_CLASSIC", + "MTEB_MAIN_RU", + "RU_SCI_BENCH", + "MTEB_RETRIEVAL_WITH_INSTRUCTIONS", + "MTEB_RETRIEVAL_LAW", + "MTEB_RETRIEVAL_MEDICAL", + "MTEB_MINERS_BITEXT_MINING", + "SEB", + "CoIR", + "RAR_b", + "MTEB_FRA", + "MTEB_DEU", + "MTEB_KOR", + "MTEB_POL", + "MTEB_code", + "MTEB_multilingual_v1", + "MTEB_multilingual_v2", + "MTEB_JPN", + "MTEB_INDIC", + "MTEB_EU", + "LONG_EMBED", + "BRIGHT", + "BRIGHT_LONG", + "CODE_RAG", + "BEIR", + "NANOBEIR", + "C_MTEB", + "FA_MTEB", + "CHEMTEB", + "BEIR_NL", + "MIEB_ENG", + "MIEB_MULTILINGUAL", + "MIEB_LITE", + "MIEB_IMG", + "BUILT_MTEB", + "ENCODECHKA", + "VIDORE", + "VIDORE_V2", + "VISUAL_DOCUMENT_RETRIEVAL", + "R2MED", + "VN_MTEB", + "JINA_VDR", +] diff --git a/mteb/benchmarks/benchmarks.py b/mteb/benchmarks/benchmarks/benchmarks.py similarity index 99% rename from mteb/benchmarks/benchmarks.py rename to mteb/benchmarks/benchmarks/benchmarks.py index 980f4c4c9d..0bae157929 100644 --- a/mteb/benchmarks/benchmarks.py +++ 
b/mteb/benchmarks/benchmarks/benchmarks.py @@ -855,7 +855,7 @@ ], ) -MTEB_multilingual = Benchmark( +MTEB_multilingual_v1 = Benchmark( name="MTEB(Multilingual, v1)", display_name="Multilingual", icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-gui-globe.svg", @@ -869,7 +869,7 @@ ) -MTEB_multilingual = Benchmark( +MTEB_multilingual_v2 = Benchmark( name="MTEB(Multilingual, v2)", display_name="Multilingual", icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-gui-globe.svg", diff --git a/mteb/benchmarks/benchmarks/rteb_benchmarks.py b/mteb/benchmarks/benchmarks/rteb_benchmarks.py new file mode 100644 index 0000000000..508009fbc1 --- /dev/null +++ b/mteb/benchmarks/benchmarks/rteb_benchmarks.py @@ -0,0 +1,177 @@ +# RTEB Benchmarks - Retrieval Embedding Benchmark +from __future__ import annotations + +from mteb.benchmarks.benchmark import Benchmark +from mteb.overview import get_tasks + +RTEB_CITATION = r"""@article{rteb2024, + author = {RTEB Authors}, + title = {RTEB: Retrieval Embedding Benchmark for Multi-Domain Text Retrieval}, + year = {2024}, +}""" + +RTEB_MAIN = Benchmark( + name="RTEB(beta)", + display_name="RTEB Retrieval Embedding Benchmark", + icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-gui-search.svg", + tasks=get_tasks( + tasks=[ + "AILACasedocs", + "AILAStatutes", + "LegalSummarization", + "LegalQuAD", + "FinanceBenchRetrieval", + "HC3FinanceRetrieval", + "FinQARetrieval", + "AppsRetrieval", + "DS1000Retrieval", + "HumanEvalRetrieval", + "MBPPRetrieval", + "WikiSQLRetrieval", + "FreshStackRetrieval", + "ChatDoctorRetrieval", + "CUREv1", + ], + ), + description="RTEB (Retrieval Embedding Benchmark) is a comprehensive benchmark for evaluating text retrieval models across multiple specialized domains including legal, finance, code, and healthcare. 
It contains 15 diverse retrieval tasks designed to test models' ability to understand domain-specific terminology and retrieve relevant documents in specialized contexts.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_ENGLISH = Benchmark( + name="RTEB(eng, beta)", + display_name="RTEB English", + icon="https://github.com/lipis/flag-icons/raw/refs/heads/main/flags/4x3/us.svg", + tasks=get_tasks( + tasks=[ + "AILACasedocs", + "AILAStatutes", + "LegalSummarization", + "FinanceBenchRetrieval", + "HC3FinanceRetrieval", + "FinQARetrieval", + "AppsRetrieval", + "DS1000Retrieval", + "HumanEvalRetrieval", + "MBPPRetrieval", + "WikiSQLRetrieval", + "FreshStackRetrieval", + "ChatDoctorRetrieval", + "CUREv1", + ], + languages=["eng"], + ), + description="RTEB English subset containing retrieval tasks in English across legal, finance, code, and healthcare domains.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_FRENCH = Benchmark( + name="RTEB(fr, beta)", + display_name="RTEB French", + icon="https://github.com/lipis/flag-icons/raw/260c91531be024944c6514130c5defb2ebb02b7d/flags/4x3/fr.svg", + tasks=get_tasks( + tasks=[ + "CUREv1", + ], + languages=["fra"], + ), + description="RTEB French subset containing retrieval tasks in French across multiple domains.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_GERMAN = Benchmark( + name="RTEB(deu, beta)", + display_name="RTEB German", + icon="https://github.com/lipis/flag-icons/raw/260c91531be024944c6514130c5defb2ebb02b7d/flags/4x3/de.svg", + tasks=get_tasks( + tasks=[ + "LegalQuAD", + ], + ), + description="RTEB German subset containing retrieval tasks in German, focusing on legal domain.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_JAPANESE = Benchmark( + name="RTEB(jpn, beta)", + display_name="RTEB Japanese", + icon="https://github.com/lipis/flag-icons/raw/260c91531be024944c6514130c5defb2ebb02b7d/flags/4x3/jp.svg", + tasks=get_tasks( + tasks=[ + # Japanese tasks would go here when available + ], + ), + description="RTEB Japanese subset containing retrieval tasks in Japanese across multiple domains.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_FINANCE = Benchmark( + name="RTEB(fin, beta)", + display_name="RTEB Finance", + icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-gui-price-tag.svg", + tasks=get_tasks( + tasks=[ + "FinanceBenchRetrieval", + "HC3FinanceRetrieval", + "FinQARetrieval", + ], + ), + description="RTEB Finance subset containing retrieval tasks specifically focused on financial domain including finance benchmarks, Q&A, and financial document retrieval.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_LEGAL = Benchmark( + name="RTEB(Law, beta)", + display_name="RTEB Legal", + icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-map-library.svg", + tasks=get_tasks( + tasks=[ + "AILACasedocs", + "AILAStatutes", + "LegalSummarization", + "LegalQuAD", + ], + ), + description="RTEB Legal subset containing retrieval tasks specifically focused on legal domain including case documents, statutes, legal summarization, and legal Q&A.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_CODE = Benchmark( + name="RTEB(Code, beta)", + display_name="RTEB Code", + icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-tech-electronics.svg", + tasks=get_tasks( + tasks=[ + "AppsRetrieval", + 
"DS1000Retrieval", + "HumanEvalRetrieval", + "MBPPRetrieval", + "WikiSQLRetrieval", + "FreshStackRetrieval", + ], + ), + description="RTEB Code subset containing retrieval tasks specifically focused on programming and code domains including algorithmic problems, data science tasks, code evaluation, and SQL retrieval.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) + +RTEB_HEALTHCARE = Benchmark( + name="RTEB(Health, beta)", + display_name="RTEB Healthcare", + icon="https://github.com/DennisSuitters/LibreICONS/raw/2d2172d15e3c6ca03c018629d60050e4b99e5c55/svg-color/libre-map-hospital.svg", + tasks=get_tasks( + tasks=[ + "ChatDoctorRetrieval", + "CUREv1", + ], + ), + description="RTEB Healthcare subset containing retrieval tasks specifically focused on healthcare and medical domains including medical Q&A, healthcare information retrieval, and cross-lingual medical retrieval.", + citation=RTEB_CITATION, + contacts=["fzowl"], +) diff --git a/mteb/benchmarks/get_benchmark.py b/mteb/benchmarks/get_benchmark.py index bbf4fbfe50..6c2a4382aa 100644 --- a/mteb/benchmarks/get_benchmark.py +++ b/mteb/benchmarks/get_benchmark.py @@ -24,7 +24,7 @@ SEB, Benchmark, MTEB_code, - MTEB_multilingual, + MTEB_multilingual_v2, ) logger = logging.getLogger(__name__) @@ -48,7 +48,7 @@ "MTEB(kor)": MTEB_KOR.name, "MTEB(pol)": MTEB_POL.name, "MTEB(code)": MTEB_code.name, - "MTEB(Multilingual)": MTEB_multilingual.name, + "MTEB(Multilingual)": MTEB_multilingual_v2.name, "MTEB(jpn)": MTEB_JPN.name, "MTEB(Indic)": MTEB_INDIC.name, "MTEB(Europe)": MTEB_EU.name, diff --git a/mteb/leaderboard/app.py b/mteb/leaderboard/app.py index 2dcb4d96be..3c0921ab05 100644 --- a/mteb/leaderboard/app.py +++ b/mteb/leaderboard/app.py @@ -28,7 +28,6 @@ logger = logging.getLogger(__name__) - LANGUAGE: list[str] = list({l for t in mteb.get_tasks() for l in t.metadata.languages}) ALL_MODELS = {meta.name for meta in mteb.get_model_metas()} @@ -54,8 +53,9 @@ def produce_benchmark_link(benchmark_name: str, request: gr.Request) -> str: } ) base_url = request.request.base_url + md = "You can also share this benchmark using the following link:\n" url = f"{base_url}?{params}" - md = f"```\n{url}\n```" + md += f"```\n{url}\n```" return md @@ -73,7 +73,8 @@ def download_table(table: pd.DataFrame) -> str: def update_citation(benchmark_name: str) -> str: benchmark = mteb.get_benchmark(benchmark_name) if benchmark.citation is not None: - citation = f"```bibtex\n{benchmark.citation}\n```" + citation = "To cite this work, please use the following reference:\n" + citation += f"```bibtex\n{benchmark.citation}\n```" else: citation = "" return citation @@ -297,98 +298,91 @@ def get_leaderboard_app() -> gr.Blocks: update_description, inputs=[benchmark_select, lang_select, type_select, domain_select], ) - with gr.Accordion("Cite this benchmark:", open=False): + + with gr.Column(scale=1): + with gr.Accordion("Cite and share this benchmark", open=False): citation = gr.Markdown(update_citation, inputs=[benchmark_select]) # noqa: F841 - with gr.Accordion("Share this benchmark:", open=False): gr.Markdown(produce_benchmark_link, inputs=[benchmark_select]) - with gr.Column(scale=2): - with gr.Tab("Performance per Model Size"): - plot = gr.Plot(performance_size_plot, inputs=[summary_table]) # noqa: F841 - gr.Markdown( - "*We only display models that have been run on all tasks in the benchmark*" - ) - with gr.Tab("Performance per Task Type (Radar Chart)"): - radar_plot = gr.Plot(radar_chart, inputs=[summary_table]) # noqa: F841 - gr.Markdown( - "*We only display 
models that have been run on all task types in the benchmark*" - ) - - with gr.Accordion("Customize this Benchmark", open=False): - with gr.Column(): - with gr.Row(): - type_select.render() - with gr.Row(): - domain_select.render() - with gr.Row(): - modality_select.render() - with gr.Row(elem_classes="overflow-y-scroll max-h-80"): - lang_select.render() - with gr.Row(elem_classes="overflow-y-scroll max-h-80"): - task_select.render() - - with gr.Accordion("Advanced Model Filters", open=False): - with gr.Group(): - with gr.Row(elem_classes=""): + + with gr.Accordion( + "Customize this Benchmark", + open=False, + ): with gr.Column(): - compatibility = gr.CheckboxGroup( - [ - ( - "Should be sentence-transformers compatible", - "Sentence Transformers", + with gr.Row(): + type_select.render() + with gr.Row(): + domain_select.render() + with gr.Row(): + modality_select.render() + with gr.Row(elem_classes="overflow-y-scroll max-h-80"): + lang_select.render() + with gr.Row(elem_classes="overflow-y-scroll max-h-80"): + task_select.render() + + with gr.Accordion("Advanced Model Filters", open=False): + with gr.Group(): + with gr.Row(elem_classes=""): + with gr.Column(): + compatibility = gr.CheckboxGroup( + [ + ( + "Should be sentence-transformers compatible", + "Sentence Transformers", + ) + ], + value=[], + label="Compatibility", + interactive=True, + ) + availability = gr.Radio( + [ + ("Only Open", True), + ("Only Proprietary", False), + ("Both", None), + ], + value=None, + label="Availability", + interactive=True, + ) + instructions = gr.Radio( + [ + ("Only Instruction-tuned", True), + ("Only non-instruction", False), + ("Both", None), + ], + value=None, + label="Instructions", + interactive=True, + ) + with gr.Column(): + zero_shot = gr.Radio( + [ + ( + "Only Zero-shot", + "only_zero_shot", + ), + ("Remove Unknown", "remove_unknown"), + ("Allow All", "allow_all"), + ], + value="allow_all", + label="Zero-shot", + interactive=True, + ) + + max_model_size = gr.Radio( + [ + ("<100M", 100), + ("<500M", 500), + ("<1B", 1000), + ("<5B", 5000), + ("<10B", 10000), + (">10B", MAX_MODEL_SIZE), + ], + value=MAX_MODEL_SIZE, + label="Model Parameters", + interactive=True, ) - ], - value=[], - label="Compatibility", - interactive=True, - ) - availability = gr.Radio( - [ - ("Only Open", True), - ("Only Proprietary", False), - ("Both", None), - ], - value=None, - label="Availability", - interactive=True, - ) - instructions = gr.Radio( - [ - ("Only Instruction-tuned", True), - ("Only non-instruction", False), - ("Both", None), - ], - value=None, - label="Instructions", - interactive=True, - ) - with gr.Column(): - zero_shot = gr.Radio( - [ - ( - "Only Zero-shot", - "only_zero_shot", - ), - ("Remove Unknown", "remove_unknown"), - ("Allow All", "allow_all"), - ], - value="allow_all", - label="Zero-shot", - interactive=True, - ) - - max_model_size = gr.Radio( - [ - ("<100M", 100), - ("<500M", 500), - ("<1B", 1000), - ("<5B", 5000), - ("<10B", 10000), - (">10B", MAX_MODEL_SIZE), - ], - value=MAX_MODEL_SIZE, - label="Model Parameters", - interactive=True, - ) with gr.Tab("Summary"): summary_table.render() @@ -402,6 +396,25 @@ def get_leaderboard_app() -> gr.Blocks: open=False, ): gr.Markdown(FAQ) + + with gr.Tab("Performance per Model Size") as plot_tab: + plot = gr.Plot(performance_size_plot, inputs=[summary_table]) # noqa: F841 + gr.Markdown( + "*We only display TOP 5 models that have been run on all tasks in the benchmark*" + ) + plot_tab.select( + performance_size_plot, inputs=[summary_table], outputs=[plot] + ) 
+ + with gr.Tab("Performance per Task Type") as radar_plot_tab: + radar_plot = gr.Plot(radar_chart, inputs=[summary_table]) # noqa: F841 + gr.Markdown( + "*We only display TOP 5 models that have been run on all task types in the benchmark*" + ) + radar_plot_tab.select( + radar_chart, inputs=[summary_table], outputs=[radar_plot] + ) + with gr.Tab("Performance per task"): per_task_table.render() download_per_task = gr.DownloadButton("Download Table") diff --git a/mteb/leaderboard/benchmark_selector.py b/mteb/leaderboard/benchmark_selector.py index dff2c8edc9..143d0ef2a7 100644 --- a/mteb/leaderboard/benchmark_selector.py +++ b/mteb/leaderboard/benchmark_selector.py @@ -7,6 +7,9 @@ import mteb from mteb import Benchmark +from mteb.benchmarks.benchmarks import MTEB_multilingual_v2 + +DEFAULT_BENCHMARK_NAME = MTEB_multilingual_v2.name DEFAULT_BENCHMARK_NAME = MTEB_multilingual.name diff --git a/mteb/models/bedrock_models.py b/mteb/models/bedrock_models.py index a97535960c..46e3f02113 100644 --- a/mteb/models/bedrock_models.py +++ b/mteb/models/bedrock_models.py @@ -39,11 +39,7 @@ def __init__( self._provider = provider.lower() if self._provider == "cohere": - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) - if model_prompts - else None - ) + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) self._max_batch_size = 96 self._max_sequence_length = max_tokens * 4 else: diff --git a/mteb/models/codi_models.py b/mteb/models/codi_models.py index 10c61ff0db..c0016c780f 100644 --- a/mteb/models/codi_models.py +++ b/mteb/models/codi_models.py @@ -12,35 +12,35 @@ codi_instruction = { "CmedqaRetrieval": { "query": "Given a Chinese community medical question, retrieve replies that best answer the question", - "passage": "", + "document": "", }, "CovidRetrieval": { "query": "Given a question on COVID-19, retrieve news articles that answer the question", - "passage": "", + "document": "", }, "DuRetrieval": { "query": "Given a Chinese search query, retrieve web passages that answer the question", - "passage": "", + "document": "", }, "EcomRetrieval": { "query": "Given a user query from an e-commerce website, retrieve description sentences of relevant products", - "passage": "", + "document": "", }, "MedicalRetrieval": { "query": "Given a medical question, retrieve user replies that best answer the question", - "passage": "", + "document": "", }, "MMarcoRetrieval": { "query": "Given a web search query, retrieve relevant passages that answer the query", - "passage": "", + "document": "", }, "T2Retrieval": { "query": "Given a Chinese search query, retrieve web passages that answer the question", - "passage": "", + "document": "", }, "VideoRetrieval": { "query": "Given a video search query, retrieve the titles of relevant videos", - "passage": "", + "document": "", }, "AFQMC": "Represent the text in conversations between users and financial customer service, retrieve semantically similar text", "ATEC": "Represent the text in conversations between users and financial customer service, retrieve semantically similar text", @@ -51,19 +51,19 @@ "STSB": "Represent the short general domain sentences, retrieve semantically similar text", "T2Reranking": { "query": "Given a Chinese search query, retrieve web passages that answer the question", - "passage": "", + "document": "", }, "MMarcoReranking": { "query": "Given a web search query, retrieve relevant passages that answer the query", - "passage": "", + "document": "", }, "CMedQAv1-reranking": { "query": "Given a Chinese 
community medical question, retrieve replies that best answer the question", - "passage": "", + "document": "", }, "CMedQAv2-reranking": { "query": "Given a Chinese community medical question, retrieve replies that best answer the question", - "passage": "", + "document": "", }, "Ocnli": "Retrieve semantically similar text", "Cmnli": "Retrieve semantically similar text", @@ -83,7 +83,7 @@ def instruction_template( instruction: str, prompt_type: PromptType | None = None ) -> str: - if not instruction or prompt_type == PromptType.document: + if not instruction or prompt_type == PromptType.passage: return "" if isinstance(instruction, dict): if prompt_type is None: diff --git a/mteb/models/cohere_models.py b/mteb/models/cohere_models.py index 606195417a..dbbfd35dfa 100644 --- a/mteb/models/cohere_models.py +++ b/mteb/models/cohere_models.py @@ -135,9 +135,7 @@ def __init__( ) -> None: self.model_name = model_name self.sep = sep - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) if model_prompts else None - ) + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) def _embed( self, diff --git a/mteb/models/colbert_models.py b/mteb/models/colbert_models.py index 82c0e1d1e2..c0e22ea306 100644 --- a/mteb/models/colbert_models.py +++ b/mteb/models/colbert_models.py @@ -40,17 +40,11 @@ def __init__( self.model_name = model_name self.model = colbert_model.ColBERT(self.model_name, revision=revision, **kwargs) - if ( - model_prompts is None - and hasattr(self.model, "prompts") - and len(self.model.prompts) > 0 - ): - try: - model_prompts = self.validate_task_to_prompt_name(self.model.prompts) - except ValueError: - model_prompts = None - elif model_prompts is not None and hasattr(self.model, "prompts"): - logger.info(f"Model prompts will be overwritten with {model_prompts}") + built_in_prompts = getattr(self.model, "prompts", None) + if built_in_prompts and not model_prompts: + model_prompts = built_in_prompts + elif model_prompts and built_in_prompts: + logger.info(f"Model.prompts will be overwritten with {model_prompts}") self.model.prompts = model_prompts self.model_prompts = self.validate_task_to_prompt_name(model_prompts) diff --git a/mteb/models/google_models.py b/mteb/models/google_models.py index 2ef93b261b..9636d1ded2 100644 --- a/mteb/models/google_models.py +++ b/mteb/models/google_models.py @@ -60,9 +60,7 @@ def __init__( **kwargs, ) -> None: self.model_name = model_name - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) if model_prompts else None - ) + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) def _embed( self, diff --git a/mteb/models/llm2vec_models.py b/mteb/models/llm2vec_models.py index 37983bc159..b73678a681 100644 --- a/mteb/models/llm2vec_models.py +++ b/mteb/models/llm2vec_models.py @@ -72,9 +72,7 @@ def __init__( extra_kwargs["attn_implementation"] = "flash_attention_2" - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) if model_prompts else None - ) + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) if device: kwargs["device_map"] = device diff --git a/mteb/models/mdbr_models.py b/mteb/models/mdbr_models.py new file mode 100644 index 0000000000..490acc9d90 --- /dev/null +++ b/mteb/models/mdbr_models.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from functools import partial + +from mteb.model_meta import ModelMeta, sentence_transformers_loader +from mteb.models.arctic_models import arctic_v1_training_datasets +from 
mteb.models.mxbai_models import mixedbread_training_data + +model_prompts = {"query": "Represent this sentence for searching relevant passages: "} + +LEAF_TRAINING_DATASETS = { + "AmazonQA": ["train"], + "LoTTE": ["dev", "test"], + # FineWeb + # CC-News + # PubMedQA + # TriviaQA +} + +mdbr_leaf_ir = ModelMeta( + loader=partial( # type: ignore + sentence_transformers_loader, + model_name="MongoDB/mdbr-leaf-ir", + revision="2e46f5aac796e621d51f678c306a66ede4712ecb", + model_prompts=model_prompts, + ), + name="MongoDB/mdbr-leaf-ir", + revision="2e46f5aac796e621d51f678c306a66ede4712ecb", + release_date="2025-08-27", + languages=["eng-Latn"], + open_weights=True, + framework=["Sentence Transformers", "PyTorch"], + n_parameters=22_861_056, + memory_usage_mb=86, + max_tokens=512, + embed_dim=768, + license="apache-2.0", + reference="https://huggingface.co/MongoDB/mdbr-leaf-ir", + similarity_fn_name="cosine", + use_instructions=True, + adapted_from="nreimers/MiniLM-L6-H384-uncased", + superseded_by=None, + public_training_code=None, + public_training_data=None, + training_datasets={**LEAF_TRAINING_DATASETS, **arctic_v1_training_datasets}, +) + +mdbr_leaf_mt = ModelMeta( + loader=partial( # type: ignore + sentence_transformers_loader, + model_name="MongoDB/mdbr-leaf-mt", + revision="66c47ba6d753efc208d54412b5af6c744a39a4df", + model_prompts=model_prompts, + ), + name="MongoDB/mdbr-leaf-mt", + revision="66c47ba6d753efc208d54412b5af6c744a39a4df", + release_date="2025-08-27", + languages=["eng-Latn"], + open_weights=True, + framework=["Sentence Transformers", "PyTorch"], + n_parameters=22_958_592, + memory_usage_mb=86, + max_tokens=512, + embed_dim=1024, + license="apache-2.0", + reference="https://huggingface.co/MongoDB/mdbr-leaf-mt", + similarity_fn_name="cosine", + use_instructions=True, + adapted_from="nreimers/MiniLM-L6-H384-uncased", + superseded_by=None, + public_training_code=None, + public_training_data=None, + training_datasets={**LEAF_TRAINING_DATASETS, **mixedbread_training_data}, +) diff --git a/mteb/models/ordalietech_solon_embeddings_mini_beta_1_1.py b/mteb/models/ordalietech_solon_embeddings_mini_beta_1_1.py new file mode 100644 index 0000000000..7dfd0cc2cd --- /dev/null +++ b/mteb/models/ordalietech_solon_embeddings_mini_beta_1_1.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from functools import partial + +from mteb.model_meta import ModelMeta, sentence_transformers_loader + +solon_embeddings_1_1 = ModelMeta( + name="OrdalieTech/Solon-embeddings-mini-beta-1.1", + languages=["fra-Latn"], + n_parameters=210_000_000, + public_training_code=None, + memory_usage_mb=808.0, + open_weights=True, + revision="8e4ea66eb7eb6109b47b7d97d7556f154d9aec4a", + release_date="2025-01-01", + embed_dim=768, + license="apache-2.0", + max_tokens=8192, + reference="https://huggingface.co/OrdalieTech/Solon-embeddings-mini-beta-1.1", + similarity_fn_name="cosine", + framework=["Sentence Transformers", "PyTorch"], + use_instructions=False, + public_training_data=( + "https://huggingface.co/datasets/PleIAs/common_corpus; " + "https://huggingface.co/datasets/HuggingFaceFW/fineweb; " + "https://huggingface.co/datasets/OrdalieTech/wiki_fr; " + "private LLM-synthetic (train)" + ), + loader=partial( + sentence_transformers_loader, + model_name="OrdalieTech/Solon-embeddings-mini-beta-1.1", + revision="8e4ea66eb7eb6109b47b7d97d7556f154d9aec4a", + trust_remote_code=True, + ), + training_datasets={}, # No mteb dataset +) diff --git a/mteb/models/overview.py b/mteb/models/overview.py index 
150570c3cc..edcf8cc883 100644 --- a/mteb/models/overview.py +++ b/mteb/models/overview.py @@ -67,6 +67,7 @@ llm2vec_models, mcinext_models, mctct_model, + mdbr_models, misc_models, mms_models, moco_models, @@ -85,6 +86,7 @@ openclip_models, opensearch_neural_sparse_models, ops_moa_models, + ordalietech_solon_embeddings_mini_beta_1_1, piccolo_models, promptriever_models, qodo_models, @@ -177,6 +179,7 @@ listconranker, llm2clip_models, llm2vec_models, + mdbr_models, misc_models, mctct_model, model2vec_models, @@ -195,6 +198,7 @@ openclip_models, opensearch_neural_sparse_models, ops_moa_models, + ordalietech_solon_embeddings_mini_beta_1_1, piccolo_models, gme_v_models, promptriever_models, diff --git a/mteb/models/qzhou_models.py b/mteb/models/qzhou_models.py index 20e92f3d16..bab9fb969e 100644 --- a/mteb/models/qzhou_models.py +++ b/mteb/models/qzhou_models.py @@ -16,7 +16,7 @@ def instruction_template( instruction: str, prompt_type: PromptType | None = None ) -> str: - if not instruction or prompt_type == PromptType.document: + if not instruction or prompt_type == PromptType.passage: return "" if isinstance(instruction, dict): if prompt_type is None: diff --git a/mteb/models/repllama_models.py b/mteb/models/repllama_models.py index e704cb865a..2a135712e2 100644 --- a/mteb/models/repllama_models.py +++ b/mteb/models/repllama_models.py @@ -47,9 +47,7 @@ def __init__( # set the max_length for the evals as they did, although the model can handle longer self.model.config.max_length = 512 self.tokenizer.model_max_length = 512 - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) if model_prompts else None - ) + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) def create_batch_dict(self, tokenizer, input_texts): max_length = self.model.config.max_length diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index e9d5492803..d27584b96c 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -39,28 +39,35 @@ def __init__( self.model = model if ( - model_prompts is None - and hasattr(self.model, "prompts") - and len(self.model.prompts) > 0 - ): - try: - model_prompts = self.validate_task_to_prompt_name(self.model.prompts) - - if ( - len(self.model.prompts) == 2 - and self.model.prompts.get("query", "") == "" - and self.model.prompts.get("document", "") == "" - ): - model_prompts = None - except KeyError: - model_prompts = None - logger.warning( - "Model prompts are not in the expected format. Ignoring them." - ) - elif model_prompts is not None and hasattr(self.model, "prompts"): - logger.info(f"Model prompts will be overwritten with {model_prompts}") + built_in_prompts := getattr(self.model, "prompts", None) + ) and not model_prompts: + model_prompts = built_in_prompts + elif model_prompts and built_in_prompts: + logger.warning(f"Model prompts will be overwritten with {model_prompts}") self.model.prompts = model_prompts - self.model_prompts = self.validate_task_to_prompt_name(model_prompts) + + self.model_prompts, invalid_prompts = self.validate_task_to_prompt_name( + model_prompts, raise_for_invalid_keys=False + ) + + if invalid_prompts: + invalid_prompts = "\n".join(invalid_prompts) + logger.warning( + f"Some prompts are not in the expected format and will be ignored. 
Problems:\n\n{invalid_prompts}" + ) + + if ( + self.model_prompts + and len(self.model_prompts) <= 2 + and ( + PromptType.query.value not in self.model_prompts + or PromptType.document.value not in self.model_prompts + ) + ): + logger.warning( + "SentenceTransformers that use prompts most often need to be configured with at least 'query' and" + f" 'document' prompts to ensure optimal performance. Received {self.model_prompts}" + ) if isinstance(self.model, CrossEncoder): self.predict = self._predict diff --git a/mteb/models/voyage_models.py b/mteb/models/voyage_models.py index 15c934d573..8a222e5d9d 100644 --- a/mteb/models/voyage_models.py +++ b/mteb/models/voyage_models.py @@ -16,6 +16,24 @@ # synthetic data } +# Total token limits per model based on VoyageAI documentation +VOYAGE_TOTAL_TOKEN_LIMITS = { + "voyage-3.5-lite": 1_000_000, + "voyage-3.5": 320_000, + "voyage-2": 320_000, + "voyage-3-large": 120_000, + "voyage-code-3": 120_000, + "voyage-large-2-instruct": 120_000, + "voyage-finance-2": 120_000, + "voyage-multilingual-2": 120_000, + "voyage-law-2": 120_000, + "voyage-large-2": 120_000, + "voyage-3": 120_000, + "voyage-3-lite": 120_000, + "voyage-code-2": 120_000, + "voyage-3-m-exp": 120_000, +} + def token_limit(max_tpm: int, interval: int = 60): limit_interval_start_ts = time.time() @@ -75,6 +93,7 @@ def __init__( max_retries: int = 5, max_rpm: int = 300, max_tpm: int = 1_000_000, + max_tokens: int | None = None, model_prompts: dict[str, str] | None = None, **kwargs, ) -> None: @@ -85,19 +104,32 @@ def __init__( self._embed_func = rate_limit(max_rpm)(token_limit(max_tpm)(self._client.embed)) self._model_name = model_name self._max_tpm = max_tpm - self.model_prompts = ( - self.validate_task_to_prompt_name(model_prompts) if model_prompts else None - ) + self._max_tokens = max_tokens + self.model_prompts = self.validate_task_to_prompt_name(model_prompts) + + def _calculate_default_batch_size(self) -> int: + """Calculate the default batch size based on total token limit and context length. 
+ + Formula: floor(total_token_limit / context_length) + """ + if self._max_tokens is None: + return 32 # fallback to original default + + total_token_limit = VOYAGE_TOTAL_TOKEN_LIMITS.get(self._model_name, 120_000) + return max(1, total_token_limit // self._max_tokens) def encode( self, sentences: list[str], *, - batch_size: int = 32, + batch_size: int | None = None, task_name: str, prompt_type: PromptType | None = None, **kwargs: Any, ) -> np.ndarray: + if batch_size is None: + batch_size = self._calculate_default_batch_size() + prompt_name = self.get_prompt_name(self.model_prompts, task_name, prompt_type) input_type = self.model_prompts.get(prompt_name, "document") @@ -151,6 +183,7 @@ def _batched_encode( loader=partial( VoyageWrapper, model_name="voyage-3.5", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -176,6 +209,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-large-2-instruct", + max_tokens=16000, model_prompts=model_prompts, ), max_tokens=16000, @@ -201,6 +235,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-finance-2", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -226,6 +261,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-law-2", + max_tokens=16000, model_prompts=model_prompts, ), max_tokens=16000, @@ -251,6 +287,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-code-2", + max_tokens=16000, model_prompts=model_prompts, ), max_tokens=16000, @@ -276,6 +313,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-code-3", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -302,6 +340,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-large-2", + max_tokens=16000, model_prompts=model_prompts, ), max_tokens=16000, @@ -327,6 +366,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-2", + max_tokens=4000, model_prompts=model_prompts, ), max_tokens=4000, @@ -351,6 +391,7 @@ def _batched_encode( loader=partial( # type: ignore VoyageWrapper, model_name="voyage-multilingual-2", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -376,6 +417,7 @@ def _batched_encode( loader=partial( VoyageWrapper, model_name="voyage-3", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -401,6 +443,7 @@ def _batched_encode( loader=partial( VoyageWrapper, model_name="voyage-3-lite", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, @@ -426,6 +469,7 @@ def _batched_encode( loader=partial( VoyageWrapper, model_name="voyage-3-m-exp", + max_tokens=32000, model_prompts=model_prompts, ), max_tokens=32000, diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py index 68ec09ae62..ccbdc59713 100644 --- a/mteb/models/wrapper.py +++ b/mteb/models/wrapper.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Callable, get_args +from collections.abc import Callable, Sequence +from typing import Literal, get_args, overload import mteb from mteb.abstasks.TaskMetadata import TASK_TYPE @@ -65,29 +66,95 @@ def get_prompt_name( return None @staticmethod + @overload def validate_task_to_prompt_name( - task_to_prompt_name: dict[str, str] | None, - ) -> dict[str, str] | None: - if task_to_prompt_name is None: - return task_to_prompt_name + task_to_prompt: dict[str, str] | None, 
+ raise_for_invalid_keys: Literal[True] = True, + ) -> dict[str, str] | None: ... + + @staticmethod + @overload + def validate_task_to_prompt_name( + task_to_prompt: dict[str, str] | None, + raise_for_invalid_keys: Literal[False] = False, + ) -> tuple[dict[str, str], Sequence[str]] | tuple[None, None]: ... + + @staticmethod + def validate_task_to_prompt_name( + task_to_prompt: dict[str, str] | None, + raise_for_invalid_keys: bool = True, + ) -> ( + dict[str, str] | tuple[dict[str, str], Sequence[str]] | tuple[None, None] | None + ): + """Validates that the keys in task_to_prompt map to a known task or prompt type. + + A key is valid if: + + 1. It is a valid task name; or + 2. It is a valid task type; or + 3. It is a valid prompt type; or + 4. It is a compound key of the form "{task_name}-{prompt_type}" where task_name is a valid task type or task + name and prompt_type is a valid prompt type. + + See the + [MTEB docs](https://github.com/embeddings-benchmark/mteb/blob/main/docs/usage/usage.md#running-sentencetransformer-model-with-prompts) + for a complete description of the order of precedence for these keys when running an evaluation. + + Arguments: + task_to_prompt: The dictionary of prompts. + raise_for_invalid_keys: If True, raise an error when an invalid key is encountered, otherwise return the + list of error messages along with a filtered dictionary of prompts with valid keys. Defaults to True + for backward compatibility. + + Returns: + * None if `task_to_prompt` is None or empty; + * Only a dictionary of validated prompts if `raise_for_invalid_keys` is `True`; or + * A tuple containing the filtered dictionary of valid prompts and the set of error messages for the + invalid prompts if `raise_for_invalid_keys` is `False`. + + Raises: + KeyError: If any invalid keys are encountered and `raise_for_invalid_keys` is `True`, this function will + raise a single `KeyError` containing all of the collected error messages. + """ + if not task_to_prompt: + return None if raise_for_invalid_keys else (None, None) + task_types = get_args(TASK_TYPE) prompt_types = [e.value for e in PromptType] - for task_name in task_to_prompt_name: - if "-" in task_name and task_name.endswith( - (f"-{PromptType.query.value}", f"-{PromptType.document.value}") - ): - task_name, prompt_type = task_name.rsplit("-", 1) - if prompt_type not in prompt_types: - msg = f"Prompt type {prompt_type} is not valid. Valid prompt types are {prompt_types}" - logger.warning(msg) - raise KeyError(msg) + valid_keys_msg = f"Valid keys are task types [{task_types}], prompt types [{prompt_types}], and task names" + valid_prompt_type_endings = tuple( + [f"-{prompt_type}" for prompt_type in prompt_types] + ) + + invalid_keys: set[str] = set() + invalid_task_messages: set[str] = set() + + for task_key in task_to_prompt: + # task_key may be a compound key of the form "{task_name}-{prompt_type}". A task_name may contain a "-" + # character (this occurs in ~12% of task names), so rsplit is used to separate a valid prompt_type postfix + # from the unvalidated task_name. + if task_key.endswith(valid_prompt_type_endings): + task_name = task_key.rsplit("-", 1)[0] + else: + task_name = task_key + if task_name not in task_types and task_name not in prompt_types: - task = mteb.get_task(task_name=task_name) - if not task: - msg = f"Task name {task_name} is not valid. Valid task names are task types [{task_types}], prompt types [{prompt_types}] and task names" + try: + mteb.get_task(task_name=task_name) + except KeyError: + msg = f"Task name {task_name} is not valid. 
{valid_keys_msg}" logger.warning(msg) - raise KeyError(msg) - return task_to_prompt_name + invalid_task_messages.add(msg) + invalid_keys.add(task_key) + + if raise_for_invalid_keys and invalid_task_messages: + raise KeyError(invalid_task_messages) + elif raise_for_invalid_keys: + return task_to_prompt + else: + return { + k: v for k, v in task_to_prompt.items() if k not in invalid_keys + }, tuple(invalid_task_messages) @staticmethod def get_instruction( diff --git a/mteb/overview.py b/mteb/overview.py index 79a5b30e6b..7bb59bb430 100644 --- a/mteb/overview.py +++ b/mteb/overview.py @@ -288,6 +288,7 @@ def get_tasks( modalities: list[MODALITIES] | None = None, exclusive_modality_filter: bool = False, exclude_aggregate: bool = False, + exclude_private: bool = True, ) -> MTEBTasks: """Get a list of tasks based on the specified filters. @@ -311,6 +312,7 @@ def get_tasks( task's modalities and ALL task modalities are in filter modalities (exact match). If False, keep tasks if _any_ of the task's modalities match the filter modalities. exclude_aggregate: If True, exclude aggregate tasks. If False, both aggregate and non-aggregate tasks are returned. + exclude_private: If True (default), exclude private/closed datasets (is_public=False). If False, include both public and private datasets. Returns: A list of all initialized tasks objects which pass all of the filters (AND operation). @@ -364,6 +366,10 @@ def get_tasks( if exclude_aggregate: _tasks = filter_aggregate_tasks(_tasks) + # Apply privacy filtering + if exclude_private: + _tasks = [t for t in _tasks if t.metadata.is_public] + return MTEBTasks(_tasks) diff --git a/pyproject.toml b/pyproject.toml index 2c44dfbd8b..fb28bbc368 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "mteb" -version = "1.38.46" +version = "1.38.52" description = "Massive Text Embedding Benchmark" readme = "README.md" authors = [ @@ -56,16 +56,6 @@ mteb = "mteb.cli:main" [project.optional-dependencies] image = ["torchvision>0.2.1"] -dev = [ -"ruff==0.11.13", # locked so we don't get PRs which fail only due to a lint update -"pytest>=8.3.4", -"pytest-xdist>=3.6.1", -"pytest-coverage>=0.0", -"pytest-rerunfailures>=15.0", -"iso639>=0.1.4", # used for tests/scripts/test_generate_model_meta.py -"pre-commit>=4.1.0", -"bibtexparser>=1.4.3" # used for tests/test_citation_formatting.py -] codecarbon = ["codecarbon>=2.0.0,<3.0.0"] speedtask = [ "setuptools!=78.0.1", # https://github.com/pypa/setuptools/issues/4910 @@ -101,10 +91,28 @@ open_clip_torch = ["open_clip_torch==2.31.0"] xet = ["huggingface_hub>=0.32.0"] ark = ["volcengine-python-sdk[ark]==3.0.2", "tiktoken>=0.8.0"] colpali_engine = ["colpali_engine>=0.3.12"] + +[dependency-groups] +lint = [ + "ruff==0.11.13", # locked so we don't get PRs which fail only due to a lint update + "pre-commit>=4.1.0", + "bibtexparser>=1.4.3" # used for tests/test_citation_formatting.py +] +test = [ + "pytest>=8.3.4,<8.4.0", + "pytest-xdist>=3.6.1,<3.7.0", + "pytest-coverage>=0.0", + "pytest-rerunfailures>=15.0,<16.0", + "iso639>=0.1.4", # used for tests/scripts/test_generate_model_meta.py +] +dev = [ + {include-group = "lint"}, + {include-group = "test"}, +] + muq = ["muq==0.1.0"] [tool.coverage.report] - omit = ["tests/*", "mteb/tasks/**/*", "scripts"] # Regexes for lines to exclude from consideration diff --git a/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py b/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py 
index 4ca0056cd7..be5377ea73 100644 --- a/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py +++ b/tests/test_benchmark/test_benchmark_integration_with_sentencetransformers.py @@ -24,7 +24,13 @@ ) def test_benchmark_sentence_transformer(task: str | AbsTask, model_name: str): """Test that a task can be fetched and run""" - if isinstance(model_name, str): - model = SentenceTransformer(model_name) + model = SentenceTransformer(model_name) + # Prior to https://github.com/embeddings-benchmark/mteb/pull/3079 the + # SentenceTransformerWrapper would set the model's prompts to None because + # the mock tasks are not in the MTEB task registry. The linked PR changes + # this behavior and keeps the prompts as configured by the model, so this + # test now sets the prompts to None explicitly to preserve the legacy + # behavior and focus the test on the tasks instead of the prompts. + model.prompts = None eval = MTEB(tasks=[task]) eval.run(model, output_folder="tests/results", overwrite_results=True) diff --git a/tests/test_overview.py b/tests/test_overview.py index 801929817c..e812ac743f 100644 --- a/tests/test_overview.py +++ b/tests/test_overview.py @@ -169,3 +169,21 @@ def test_get_tasks_with_exclusive_modality_filter(modalities): ) for task in text_tasks_exclusive: assert set(task.modalities) == set(modalities) + + +def test_get_tasks_privacy_filtering(): + """Test that get_tasks correctly filters by privacy status""" + # By default, should only return public datasets (exclude_private=True) + public_tasks = get_tasks() + + # Should include private datasets when explicitly requested + all_tasks = get_tasks(exclude_private=False) + + # All tasks should contain at least as many or more tasks than public tasks + assert len(all_tasks) >= len(public_tasks) + + # All returned tasks should be public when exclude_private=True + for task in public_tasks: + assert ( + task.metadata.is_public is not False + ) # None or True are both considered public diff --git a/tests/test_reproducible_workflow.py b/tests/test_reproducible_workflow.py index 738392c623..48f64e6496 100644 --- a/tests/test_reproducible_workflow.py +++ b/tests/test_reproducible_workflow.py @@ -14,9 +14,16 @@ logging.basicConfig(level=logging.INFO) -@pytest.mark.parametrize("task_name", ["BornholmBitextMining"]) -@pytest.mark.parametrize("model_name", ["sentence-transformers/all-MiniLM-L6-v2"]) -@pytest.mark.parametrize("model_revision", ["8b3219a92973c328a8e22fadcfa821b5dc75636a"]) +@pytest.mark.parametrize( + "task_name, model_name, model_revision", + [ + ( + "BornholmBitextMining", + "sentence-transformers/all-MiniLM-L6-v2", + "8b3219a92973c328a8e22fadcfa821b5dc75636a", + ), + ], +) def test_reproducibility_workflow(task_name: str, model_name: str, model_revision: str): """Test that a model and a task can be fetched and run in a reproducible fashion.""" model_meta = mteb.get_model_meta(model_name, revision=model_revision) @@ -67,11 +74,51 @@ def test_validate_task_to_prompt_name(task_name: str | mteb.AbsTask): Wrapper.validate_task_to_prompt_name(model_prompts) -def test_validate_task_to_prompt_name_fail(): - with pytest.raises(KeyError): - Wrapper.validate_task_to_prompt_name( - {"task_name": "prompt_name", "task_name-query": "prompt_name"} - ) +@pytest.mark.parametrize("raise_for_invalid_keys", (True, False)) +def test_validate_task_to_prompt_name_for_none(raise_for_invalid_keys: bool): + result = Wrapper.validate_task_to_prompt_name( + None, raise_for_invalid_keys=raise_for_invalid_keys + ) + assert result is None 
if raise_for_invalid_keys else result == (None, None) + +@pytest.mark.parametrize( + "task_prompt_dict", + [ + {"task_name": "prompt_name"}, + {"task_name-query": "prompt_name"}, + {"task_name-task_name": "prompt_name"}, + ], +) +def test_validate_task_to_prompt_name_fails_and_raises( + task_prompt_dict: dict[str, str], +): with pytest.raises(KeyError): - Wrapper.validate_task_to_prompt_name({"task_name-task_name": "prompt_name"}) + Wrapper.validate_task_to_prompt_name(task_prompt_dict) + + +@pytest.mark.parametrize( + "task_prompt_dict, expected_valid, expected_invalid", + [ + ({"task_name": "prompt_name"}, 0, 1), + ({"task_name-query": "prompt_name"}, 0, 1), + ( + { + "task_name-query": "prompt_name", + "query": "prompt_name", + "Retrieval": "prompt_name", + }, + 2, + 1, + ), + ({"task_name-task_name": "prompt_name"}, 0, 1), + ], +) +def test_validate_task_to_prompt_name_filters_and_reports( + task_prompt_dict: dict[str, str], expected_valid: int, expected_invalid: int +): + valid, invalid = Wrapper.validate_task_to_prompt_name( + task_prompt_dict, raise_for_invalid_keys=False + ) + assert len(valid) == expected_valid + assert len(invalid) == expected_invalid diff --git a/tests/test_tasks/test_private_tasks.py b/tests/test_tasks/test_private_tasks.py new file mode 100644 index 0000000000..50139dbdfd --- /dev/null +++ b/tests/test_tasks/test_private_tasks.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import pytest + +from mteb.overview import get_task, get_tasks + +# List of accepted private tasks - update this list as needed +ACCEPTED_PRIVATE_TASKS = [ + # Add task names here that are allowed to be private + # Example: "SomePrivateTask" +] + + +def test_private_tasks_fail_unless_accepted(): + """Test that private tasks (is_public=False) fail unless they are in the accepted list.""" + # Get all tasks including private ones + all_tasks = get_tasks(exclude_private=False) + + # Find all private tasks + private_tasks = [task for task in all_tasks if task.metadata.is_public is False] + + # Check that all private tasks are in the accepted list + for task in private_tasks: + assert task.metadata.name in ACCEPTED_PRIVATE_TASKS, ( + f"Private task '{task.metadata.name}' is not in the accepted private tasks list. " + f"Either make the task public (is_public=True) or add it to ACCEPTED_PRIVATE_TASKS." + ) + + +@pytest.mark.parametrize("task_name", ACCEPTED_PRIVATE_TASKS) +def test_accepted_private_task_exist(task_name: str): + """Test that all tasks in ACCEPTED_PRIVATE_TASKS actually exist and are private.""" + task = get_task(task_name) + assert task.metadata.is_public is False, ( + f"Task '{task_name}' is in ACCEPTED_PRIVATE_TASKS but is not private (is_public={task.metadata.is_public})" + )
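To make the dynamic batch-size heuristic added to `VoyageWrapper` easier to check, here is a small standalone sketch of the same arithmetic. The limits are copied from `VOYAGE_TOTAL_TOKEN_LIMITS` and the `max_tokens` values passed to the loaders above; the helper name is illustrative and not part of the library API.

```python
# Illustrative restatement of VoyageWrapper._calculate_default_batch_size:
# batch_size = floor(total_token_limit / context_length), floored at 1,
# with a fallback of 32 when max_tokens is unknown.
VOYAGE_TOTAL_TOKEN_LIMITS = {"voyage-3.5": 320_000, "voyage-law-2": 120_000}


def default_batch_size(model_name: str, max_tokens: int | None) -> int:
    if max_tokens is None:
        return 32  # previous hard-coded default
    total_token_limit = VOYAGE_TOTAL_TOKEN_LIMITS.get(model_name, 120_000)
    return max(1, total_token_limit // max_tokens)


assert default_batch_size("voyage-3.5", 32_000) == 10  # 320_000 // 32_000
assert default_batch_size("voyage-law-2", 16_000) == 7  # 120_000 // 16_000
assert default_batch_size("voyage-3.5", None) == 32  # fallback when max_tokens is unset
```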
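The reworked `Wrapper.validate_task_to_prompt_name` accepts prompt dictionaries keyed by prompt type, task type, task name, or a compound `{task_name}-{prompt_type}` key. A minimal sketch of a dictionary that passes validation; the prompt texts are illustrative, and `NFCorpus` is simply an example of a registered MTEB task.

```python
from mteb.models.wrapper import Wrapper

model_prompts = {
    "query": "Represent the query for retrieval: ",  # prompt type
    "Retrieval": "Represent the text for retrieval: ",  # task type
    "NFCorpus": "Given a question, retrieve relevant documents",  # task name
    "NFCorpus-query": "Given a medical question, retrieve relevant documents",  # task name + prompt type
}

# With raise_for_invalid_keys=False, unknown keys are filtered out and reported
# instead of raising a KeyError ("NotATask" here is a deliberately invalid key).
valid, errors = Wrapper.validate_task_to_prompt_name(
    {**model_prompts, "NotATask": "..."}, raise_for_invalid_keys=False
)
assert "NotATask" not in valid and len(errors) == 1
```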
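Finally, a short sketch of how the new `is_public` metadata flag and the `exclude_private` filter on `get_tasks` are intended to interact, assuming the defaults introduced in this patch (`is_public=True` on `TaskMetadata`, `exclude_private=True` on `get_tasks`); the task type filter only keeps the example small.

```python
import mteb

# Default behaviour: tasks whose metadata sets is_public=False are filtered out.
public_tasks = mteb.get_tasks(task_types=["Retrieval"])

# Opting in to private/closed datasets; these may require a Hugging Face token
# before their data can be downloaded.
all_tasks = mteb.get_tasks(task_types=["Retrieval"], exclude_private=False)

assert len(all_tasks) >= len(public_tasks)
assert all(task.metadata.is_public for task in public_tasks)
```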