diff --git a/so_vector/challenges/default.json b/so_vector/challenges/default.json index af3d7c0c2..279ce01df 100644 --- a/so_vector/challenges/default.json +++ b/so_vector/challenges/default.json @@ -47,6 +47,13 @@ "iterations": 100, "clients": 1 }, + { + "name": "knn-recall-10-50-match-all", + "operation": "knn-recall-10-50-match-all", + "warmup-iterations": 1, + "iterations": 1, + "clients": 1 + }, { "name": "script-score-query-match-all", "operation": "script-score-query-match-all", @@ -132,6 +139,13 @@ "iterations": 100, "clients": 1 }, + { + "name": "knn-recall-10-50-match-all-force-merge", + "operation": "knn-recall-10-50-match-all", + "warmup-iterations": 1, + "iterations": 1, + "clients": 1 + }, { "name": "knn-search-10-50-acceptedAnswerId-force-merge", "operation": "knn-search-10-50-acceptedAnswerId", diff --git a/so_vector/index.json b/so_vector/index.json index c3252f917..7da5cd8cc 100644 --- a/so_vector/index.json +++ b/so_vector/index.json @@ -23,7 +23,7 @@ "type": "keyword" }, "questionId": { - "type": "keyword" + "type": "long" }, "creationDate": { "type": "date" diff --git a/so_vector/operations/default.json b/so_vector/operations/default.json index 40d3aa926..e35469e58 100644 --- a/so_vector/operations/default.json +++ b/so_vector/operations/default.json @@ -53,6 +53,14 @@ "k": 10, "num_candidates": 50 }, +{ + "name": "knn-recall-10-50-match-all", + "operation-type": "knn-recall", + "param-source": "knn-recall-param-source", + "k": 10, + "num_candidates": 50, + "include-in-reporting": false +}, { "name": "script-score-query-css", "operation-type": "search", diff --git a/so_vector/queries-1k.json.bz2 b/so_vector/queries-1k.json.bz2 new file mode 100644 index 000000000..dd60e561b Binary files /dev/null and b/so_vector/queries-1k.json.bz2 differ diff --git a/so_vector/queries-recall-1k.json.bz2 b/so_vector/queries-recall-1k.json.bz2 new file mode 100644 index 000000000..21d582ce5 Binary files /dev/null and b/so_vector/queries-recall-1k.json.bz2 differ diff --git a/so_vector/queries-recall.json.bz2 b/so_vector/queries-recall.json.bz2 new file mode 100644 index 000000000..a2cd9bcac Binary files /dev/null and b/so_vector/queries-recall.json.bz2 differ diff --git a/so_vector/queries.json b/so_vector/queries.json deleted file mode 100644 index 51f89972a..000000000 --- a/so_vector/queries.json +++ /dev/null @@ -1 +0,0 @@ -[-0.03565507382154465,0.029150789603590965,-0.009953430853784084,0.016862303018569946,0.007259797770529985,0.010058729909360409,0.0031890675891190767,0.013714447617530823,-0.06062733009457588,-0.02556275762617588,-0.030024470761418343,-0.02655009739100933,-0.021072229370474815,-0.020553242415189743,0.0024740989319980145,-0.02194809541106224,0.03757188096642494,-0.009190731681883335,-0.007803409826010466,-0.017052657902240753,-0.006546805612742901,0.03078766167163849,0.025903189554810524,-0.046361058950424194,-0.03663375601172447,-0.058112241327762604,0.02732008509337902,0.0513170026242733,-0.002910953015089035,0.024234339594841003,-0.033778537064790726,-0.04715698957443237,0.055929772555828094,0.05261250585317612,-0.00042414062772877514,-0.032236892729997635,-0.01921016350388527,0.01580752618610859,-0.005356852896511555,0.10626103729009628,-0.0013599975500255823,-0.01984083652496338,0.03308602795004845,0.030549651011824608,-0.0212569460272789,0.06032968685030937,0.023432200774550438,-0.06202605366706848,0.010168108157813549,0.019757689908146858,0.022075966000556946,-0.027184976264834404,0.05304407700896263,0.01991206593811512,-0.05874328315258026,-0.022372309118509293,0.010880714282393456,0.10801274329423904,0.005823677871376276,0.10308969765901566,-0.002078634686768055,-0.07874203473329544,0.003924667835235596,0.038136523216962814,-0.06289270520210266,-0.014300704002380371,0.02404375746846199,0.0571337528526783,-0.021905988454818726,0.013701377436518669,-0.019309641793370247,0.002611628035083413,-0.02021665871143341,0.05289266258478165,-0.04801682010293007,-0.04287039861083031,0.0259495060890913,-0.03272141143679619,0.03230957314372063,0.03971446678042412,0.032413799315690994,-0.08219734579324722,0.023043571040034294,0.022732242941856384,-0.06296930462121964,0.02008066326379776,-0.04775219410657883,0.03778718411922455,-0.047711946070194244,-0.02974596619606018,-0.033333372324705124,0.05168253183364868,0.008533110842108727,-0.0011690593091771007,0.01518457755446434,0.0044088042341172695,-0.021822111681103706,0.017584646120667458,0.02231636457145214,-0.021504264324903488,0.040364328771829605,0.044239919632673264,-0.06553740799427032,0.01948888599872589,-0.009774748235940933,0.053420402109622955,-0.061048950999975204,0.007991021499037743,0.02122330851852894,-0.03152592480182648,0.0690419152379036,-0.01666679047048092,-0.02897976152598858,-0.024939894676208496,0.08396109938621521,-0.018203193321824074,0.028639059513807297,-0.019337370991706848,-0.005825735628604889,-0.02283378317952156,0.062039878219366074,-0.0321921706199646,0.03829389438033104,-0.03668420761823654,-0.019975509494543076,0.07674495130777359,-0.03523276001214981,0.00017863407265394926,-0.0038612934295088053,-0.05285559594631195,-0.03689549118280411,0.002635740442201495,0.02099185436964035,0.02761271595954895,0.030929066240787506,0.001502262195572257,0.009322592057287693,0.050910238176584244,-0.0344817228615284,-0.01972089521586895,-0.0522887147963047,0.0784289762377739,0.01015879400074482,-0.029476989060640335,-0.022646889090538025,-0.056788817048072815,-0.0008934522629715502,-0.04874531552195549,-0.00967901386320591,0.01471779216080904,-0.025908494368195534,0.051237769424915314,-0.10946392267942429,0.022627808153629303,0.06296370178461075,-0.06527986377477646,-0.057765111327171326,0.008419831283390522,0.05475125089287758,0.0014037907822057605,0.04957624524831772,0.022818686440587044,-0.010850337333977222,0.020566903054714203,-0.00596493249759078,-0.026529962196946144,-0.07000315934419632,-0.0005925678415223956,0.03815511614084244,-0.0036679604090750217,0.04319975525140762,0.020610583946108818,0.016138756647706032,-0.028218422085046768,0.004382557235658169,-0.04244203492999077,-0.03253130614757538,0.029496192932128906,0.06442058831453323,-0.05917511507868767,-0.05538521707057953,-0.06622880697250366,-0.02514919638633728,0.00997821893543005,0.03900130093097687,-0.003974937833845615,0.07689959555864334,-0.02971455454826355,-0.041060272604227066,-0.0384940430521965,0.009285518899559975,-0.020366281270980835,-0.00552419014275074,0.010278463363647461,-0.03591599687933922,-0.026300232857465744,-0.004174963571131229,0.025668716058135033,-0.010733183473348618,0.003561246208846569,0.01068275049328804,-0.04099344089627266,-0.03305772319436073,-0.01103952620178461,-0.01176391076296568,-0.03812599554657936,0.019757866859436035,0.0012890181969851255,-0.02542475424706936,-0.03283122181892395,0.023626361042261124,-0.015137449838221073,0.009961884468793869,-0.007350832689553499,0.032800860702991486,-0.003159715561196208,0.0064643616788089275,-0.03522152081131935,0.040771499276161194,0.019031954929232597,-0.009723273105919361,-0.073537178337574,0.07973514497280121,0.03280383348464966,0.06543240696191788,-0.06359760463237762,0.04215288534760475,-0.027418429031968117,-0.03710818663239479,0.009744578041136265,0.013476579450070858,0.017894737422466278,0.007020334713160992,0.007983843795955181,-0.03929414227604866,-0.023337315768003464,0.004379247780889273,-0.004015898797661066,0.024705275893211365,-0.0027054743841290474,0.007910565473139286,0.08340980112552643,0.01599128544330597,-0.0018527612555772066,0.00031564742675982416,0.005881187040358782,-0.019030582159757614,-0.04534437879920006,-0.03709614276885986,0.04108308628201485,0.022330619394779205,0.024027260020375252,-0.031039586290717125,-0.0072877188213169575,-0.03873248025774956,-0.028300173580646515,-0.0015663864323869348,-0.013528351671993732,-0.04544748738408089,-0.008431107737123966,-0.004370458889752626,0.0236565712839365,0.031668439507484436,-0.018949592486023903,-0.045675646513700485,0.03212239220738411,-0.009707520715892315,-0.030808836221694946,0.023984191939234734,-0.02123229205608368,0.04221543297171593,0.016728518530726433,-0.007254659663885832,0.0025122223887592554,-0.03743964433670044,-0.05043407902121544,0.017068667337298393,-0.08743435144424438,0.011513011530041695,0.01634378172457218,0.0014041990507394075,0.09553798288106918,0.008838274516165257,-0.04745299369096756,-0.02695905975997448,-0.034419506788253784,-0.0018368709133937955,0.07157888263463974,-0.023965518921613693,0.09751397371292114,0.027433643117547035,0.026339473202824593,0.010309661738574505,0.01399142574518919,0.02942475862801075,-0.04502158612012863,0.007889588363468647,-0.03771588206291199,-0.0017137943068519235,0.02323855645954609,-0.0021091902162879705,0.05250275880098343,-0.10714013874530792,0.004470758605748415,0.02374192513525486,-0.0025698181707412004,0.018200701102614403,0.019866930320858955,-0.0265716053545475,-0.02935563400387764,-0.03446895629167557,-0.05361831560730934,0.07387492805719376,0.004163479432463646,-0.01775449514389038,0.01857665181159973,-0.00979716144502163,0.0191548652946949,0.029887260869145393,-0.007130764424800873,0.025345578789711,-0.015894543379545212,-0.020138850435614586,0.0020117126405239105,0.0012763311387971044,0.04748288914561272,0.0027929546777158976,0.025585776194930077,-0.013460710644721985,-0.028746776282787323,0.0005618438590317965,-0.11090100556612015,-0.008552717976272106,0.01970895752310753,0.06780264526605606,-0.005113844759762287,-0.02344670332968235,0.009334192611277103,-0.009444670751690865,-0.005409246310591698,0.018911344930529594,0.008832687512040138,-0.0006177318282425404,0.0062277065590023994,-0.036681827157735825,0.11252965778112411,-0.028750795871019363,-0.023477429524064064,-0.06286631524562836,-0.04093959182500839,0.038868390023708344,0.07037753611803055,-0.0032989168539643288,-0.010539747774600983,0.0013698963448405266,0.009377953596413136,0.05339144915342331,0.010542218573391438,0.03913126140832901,0.021044984459877014,-0.00604604184627533,0.01088661141693592,-0.007079940289258957,0.025334985926747322,-0.024322209879755974,-0.04964429512619972,-0.014856045134365559,-0.025356050580739975,0.04669710993766785,0.035462357103824615,-0.012174313887953758,-0.019514625892043114,0.00023600005079060793,0.016223324462771416,0.024001216515898705,-0.04775359481573105,-0.017870768904685974,-0.03890656307339668,-0.05119255930185318,-0.033654578030109406,-0.03526986017823219,-0.05874961242079735,-0.04393085092306137,-0.026779767125844955,-0.03717479854822159,0.030165502801537514,0.002373858354985714,0.014556492678821087,-0.04489317536354065,-0.06262949854135513,0.007774992845952511,-0.015520785003900528,0.027643904089927673,-0.04647809639573097,0.01571064069867134,-0.003552223090082407,0.05668463557958603,-0.007884553633630276,0.025027187541127205,0.02255728840827942,-0.02136903628706932,-0.04616555571556091,0.06839678436517715,0.01048994716256857,-0.02151329070329666,0.01621435023844242,-0.0016323613235726953,0.001860313699580729,-0.018296919763088226,0.05440540984272957,-0.03269590064883232,-0.01972234807908535,-0.02895471639931202,-0.006595495156943798,-0.037501174956560135,-0.015340996906161308,0.021535035222768784,0.04605574160814285,0.03019796870648861,-0.018932485952973366,0.035491324961185455,0.0430041141808033,0.014098184183239937,0.02236126735806465,0.009802297689020634,-0.02696816623210907,-0.0028520182240754366,0.013460146263241768,0.04213448241353035,-0.06632637232542038,-0.0136404475197196,0.016549309715628624,-0.03244054690003395,0.019699087366461754,-0.00047389615792781115,-0.05903944373130798,0.00415915809571743,0.04429100826382637,0.0038725859485566616,0.0031527173705399036,0.025607014074921608,-0.008281109854578972,-0.15848775207996368,0.04422315955162048,0.0015401487471535802,0.06380972266197205,0.01321488432586193,-0.008954356424510479,-0.021055836230516434,-0.02416614629328251,0.05596429482102394,0.0069205984473228455,0.03783739358186722,0.04838406667113304,0.02909012883901596,0.014646432362496853,0.019111914560198784,-0.02021666057407856,-0.003243596525862813,-0.03501174971461296,0.049577727913856506,-0.04407232627272606,0.01626511663198471,-0.009797878563404083,0.026676828041672707,-0.030644778162240982,0.01642543263733387,0.009930545464158058,-0.01971464231610298,0.003496440825983882,0.07091105729341507,-0.005778995342552662,0.026004401966929436,0.013516523875296116,-0.011916893534362316,-0.036831535398960114,0.04016464948654175,-0.12716856598854065,-0.01848728582262993,0.033379532396793365,0.02160184271633625,-0.043847039341926575,-0.03579642251133919,0.11792618036270142,-0.0029324907809495926,-0.01063877809792757,-0.06984290480613708,0.003830571426078677,0.007833411917090416,0.013447368517518044,-0.002347102388739586,0.007095981854945421,-0.03578951582312584,-0.06019451841711998,0.031073935329914093,-0.007177609950304031,0.04344784468412399,-0.004873192869126797,-0.009990269318223,0.07387205958366394,-0.044872477650642395,-0.0096746189519763,0.007731420453637838,-0.0031587895937263966,0.04997938871383667,-0.016994107514619827,-0.0016591385938227177,0.017520029097795486,-0.021097693592309952,0.010463709011673927,-0.009139510802924633,0.01673535816371441,0.06014763191342354,0.03000626340508461,0.024423616006970406,-0.020157745108008385,-0.044746242463588715,0.04752214252948761,-0.022302236407995224,0.007747003808617592,-0.0021971596870571375,0.012966714799404144,0.037624627351760864,-0.024573836475610733,0.015103555284440517,0.05745525285601616,0.01937824673950672,-0.010667713358998299,-0.00689451489597559,-0.014436144381761551,-0.046800050884485245,0.0036504995077848434,0.037052929401397705,0.017737025395035744,0.014738512225449085,0.0008316689054481685,-0.003302399767562747,-0.033323682844638824,0.017181942239403725,0.023572782054543495,-0.004520328249782324,0.004148002248257399,-0.03653273731470108,-0.013946603052318096,0.02103128284215927,-0.034732408821582794,0.03156023100018501,0.02231747657060623,-0.019721290096640587,0.008886530064046383,0.0014712290139868855,0.009843145497143269,0.06848130375146866,-0.003913266584277153,-0.05590398609638214,0.002325712237507105,0.035733092576265335,-0.007341672666370869,-0.00797269493341446,0.03695487603545189,-0.013041575439274311,-0.05316924303770065,0.07956130802631378,0.007408377714455128,-0.019228488206863403,-0.0009132683044299483,-0.02691648341715336,0.013592048548161983,-0.03143993392586708,-0.04238665848970413,-0.002541846828535199,0.06360283493995667,0.025705885142087936,-0.07138311117887497,-0.012239706702530384,-0.04114243760704994,0.0009787464514374733,-0.03114210069179535,0.028999269008636475,0.004927006084471941,0.04170137271285057,0.007398371584713459,-0.02399301528930664,0.024065770208835602,0.0016199419042095542,-0.03118061274290085,0.029806338250637054,0.026011748239398003,0.01648767665028572,-0.045947860926389694,0.006785869598388672,0.08104468137025833,-0.04829707741737366,-0.04585522040724754,-0.0030510155484080315,-0.015787940472364426,0.0554991140961647,-0.015017300844192505,0.013355118222534657,0.03120446391403675,0.05623067170381546,-0.04079202562570572,-0.01882946863770485,0.035098183900117874,0.023707108572125435,-0.001042348681949079,-0.04353583604097366,-0.008083582855761051,0.06760680675506592,-0.009482468478381634,0.020736562088131905,0.0025044428184628487,-0.025875352323055267,-0.04586408659815788,-0.060153499245643616,0.03816533833742142,0.030828779563307762,0.002459413604810834,-0.030086778104305267,-0.048101216554641724,-0.023348495364189148,-0.028618134558200836,0.013447613455355167,-0.0010045123053714633,-0.04441021755337715,-0.050343066453933716,0.0684918537735939,-0.006522811483591795,0.028555244207382202,0.03996872901916504,0.08503507822751999,0.013869824819266796,-0.09197790920734406,0.06859693676233292,0.011526135727763176,-0.03449692949652672,-0.029026776552200317,-0.006731563713401556,0.02893655188381672,0.03418858349323273,0.013723241165280342,0.06560097634792328,-0.028441153466701508,-0.023732632398605347,-0.028864309191703796,-0.025746963918209076,0.04814513027667999,0.05256229266524315,-0.026192842051386833,-0.025821669027209282,-0.04349000006914139,-0.009074469096958637,0.036300789564847946,-0.05020939186215401,0.01080898568034172,0.008669303730130196,0.058075446635484695,-0.03749445080757141,-0.025824975222349167,0.06785323470830917,0.007064898498356342,0.10195902734994888,-0.019523080438375473,0.007326492574065924,-0.01357446238398552,-0.007412927690893412,0.03791226074099541,-0.027214959263801575,0.013223815709352493,0.026358697563409805,-0.02349749393761158,-0.018641194328665733,0.0564635768532753,-0.07462703436613083,-0.0016907289391383529,-0.08673571050167084,-0.004169064108282328,-0.010106505826115608,-0.0017142773140221834,0.04216575250029564,-0.03959576040506363,-0.050706442445516586,-0.023953821510076523,-0.057368580251932144,-0.03070925921201706,0.004805277101695538,-0.02523128315806389,0.005794142838567495,-0.006870611570775509,0.014444926753640175,-0.012345490977168083,-0.06555430591106415,-0.02410455048084259,0.0029828303959220648,0.005982761736959219,-0.0056174928322434425,0.04265100136399269,0.0558980330824852,-0.03698242828249931,0.017037764191627502,0.03674424812197685,-0.06586883217096329,0.026106437668204308,0.04016219824552536,-0.033117055892944336,0.0473032146692276,0.000647311971988529,0.06758075952529907,-0.007591931149363518,0.021174922585487366,-0.007785580586642027,0.003872367786243558,-0.013889883644878864,0.011402812786400318,0.00044741257443092763,-0.02686653845012188,-0.026398198679089546,0.004480898845940828,-0.03714356943964958,-0.03831327706575394,-0.010613461025059223,0.010628024116158485,-0.012937154620885849,0.02901083044707775,0.012769429944455624,0.03902195394039154,0.04641520231962204,-0.029221786186099052,0.04299160838127136,-0.02528604306280613,0.019674785435199738,-0.0345003604888916,-0.028478773310780525,0.02391287498176098,0.016952820122241974,-0.018761951476335526,0.01703314296901226,-0.004881175234913826,0.05159027874469757,-0.015350195579230785,-0.03312986344099045,0.04386473074555397,-0.010108468122780323,0.0010479703778401017,0.028348226100206375,0.0312111247330904,0.0023352340795099735,-0.013986705802381039,0.004979316610842943,-0.04611768200993538,-0.030708685517311096,0.05886486545205116,-0.028504399582743645,-0.023841600865125656,0.012596645392477512,-0.017274921759963036,0.024820733815431595,0.03505658730864525,-0.05192621424794197,0.0108442148193717,0.024518586695194244,-0.07292183488607407,0.023325953632593155,-0.0029717825818806887,-0.036330852657556534,0.008756138384342194,0.052033767104148865,0.00037948627141304314] \ No newline at end of file diff --git a/so_vector/queries.json.bz2 b/so_vector/queries.json.bz2 new file mode 100644 index 000000000..120f1b857 Binary files /dev/null and b/so_vector/queries.json.bz2 differ diff --git a/so_vector/track.py b/so_vector/track.py index e6d2ec552..f8f31bc10 100644 --- a/so_vector/track.py +++ b/so_vector/track.py @@ -1,5 +1,23 @@ +import bz2 import json +import logging import os +from typing import Any, List + +logger = logging.getLogger(__name__) +QUERIES_FILENAME: str = "queries.json.bz2" +TRUE_KNN_FILENAME: str = "queries-recall.json.bz2" +QUERIES_FILENAME_1K: str = "queries-1k.json.bz2" +TRUE_KNN_FILENAME_1K: str = "queries-recall-1k.json.bz2" + + +def compute_percentile(data: List[Any], percentile): + size = len(data) + if size <= 0: + return None + sorted_data = sorted(data) + index = int(round(percentile * size / 100)) - 1 + return sorted_data[max(min(index, size - 1), 0)] class KnnParamSource: @@ -15,18 +33,30 @@ def __init__(self, track, params, **kwargs): self._cache = params.get("cache", False) self._exact_scan = params.get("exact", False) self._params = params + self._queries = [] cwd = os.path.dirname(__file__) - with open(os.path.join(cwd, "queries.json"), "r") as file: - lines = file.readlines() - self._queries = [json.loads(line) for line in lines] + with bz2.open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file: + for vector_query in queries_file: + self._queries.append(json.loads(vector_query)) self.infinite = True + self._iters = 0 + self._maxIters = len(self._queries) def partition(self, partition_index, total_partitions): return self def params(self): result = {"index": self._index_name, "cache": self._params.get("cache", False), "size": self._params.get("k", 10)} + num_candidates = self._params.get("num_candidates", 50) + # if -1, then its unset. If set, just set it. + oversample = self._params.get("oversample", -1) + if oversample > -1 and self._exact_scan: + raise ValueError("Oversampling is not supported for exact scan queries.") + query_vec = self._queries[self._iters] + self._iters += 1 + if self._iters >= self._maxIters: + self._iters = 0 if self._exact_scan: result["body"] = { @@ -35,7 +65,7 @@ def params(self): "query": {"match_all": {}}, "script": { "source": "dotProduct(params.query, 'titleVector') + 1.0", - "params": {"query": self._queries[0]}, + "params": {"query": query_vec}, }, } }, @@ -47,16 +77,126 @@ def params(self): result["body"] = { "knn": { "field": "titleVector", - "query_vector": self._queries[0], + "query_vector": query_vec, "k": self._params.get("k", 10), - "num_candidates": self._params.get("num-candidates", 50), + "num_candidates": self._params.get("num_candidates", 50), }, "_source": False, } if "filter" in self._params: result["body"]["knn"]["filter"] = self._params["filter"] + if oversample > -1: + result["body"]["knn"]["rescore_vector"] = {"oversample": oversample} + return result +class KnnVectorStore: + def __init__(self): + cwd = os.path.dirname(__file__) + self._query_nearest_neighbor_docids = [] + self._queries = [] + with bz2.open(os.path.join(cwd, TRUE_KNN_FILENAME), "r") as queries_file: + for docids in queries_file: + self._query_nearest_neighbor_docids.append(json.loads(docids)) + with bz2.open(os.path.join(cwd, QUERIES_FILENAME), "r") as queries_file: + for vector_query in queries_file: + self._queries.append(json.loads(vector_query)) + + def get_query_vectors(self) -> List[List[float]]: + return self._queries + + def get_neighbors_for_query(self, query_id: int, size: int) -> List[str]: + if (query_id < 0) or (query_id >= len(self._query_nearest_neighbor_docids)): + raise ValueError(f"Unknown query with id: '{query_id}' provided") + if (size < 0) or (size > len(self._query_nearest_neighbor_docids[query_id])): + raise ValueError(f"Invalid size: '{size}' provided for query with id: '{query_id}'") + return self._query_nearest_neighbor_docids[query_id][:size] + + +class KnnRecallParamSource: + def __init__(self, track, params, **kwargs): + if len(track.indices) == 1: + default_index = track.indices[0].name + else: + default_index = "_all" + + self._index_name = params.get("index", default_index) + self._cache = params.get("cache", False) + self._params = params + self.infinite = True + cwd = os.path.dirname(__file__) + + def partition(self, partition_index, total_partitions): + return self + + def params(self): + return { + "index": self._index_name, + "cache": self._params.get("cache", False), + "size": self._params.get("k", 10), + "num_candidates": self._params.get("num_candidates", 50), + "oversample": self._params.get("oversample", -1), + "knn_vector_store": KnnVectorStore(), + } + + +# Used in tandem with the KnnRecallParamSource. +# reads the queries, executes knn search and compares the results with the true nearest neighbors +class KnnRecallRunner: + def get_knn_query(self, query_vec, k, num_candidates, oversample): + knn = { + "field": "titleVector", + "query_vector": query_vec, + "k": k, + "num_candidates": num_candidates, + } + if oversample > -1: + knn["rescore_vector"] = {"oversample": oversample} + return {"knn": knn, "_source": False, "docvalue_fields": ["questionId"]} + + async def __call__(self, es, params): + k = params["size"] + num_candidates = params["num_candidates"] + index = params["index"] + request_cache = params["cache"] + recall_total = 0 + exact_total = 0 + min_recall = k + max_recall = 0 + + knn_vector_store: KnnVectorStore = params["knn_vector_store"] + for query_id, query_vector in enumerate(knn_vector_store.get_query_vectors()): + knn_body = self.get_knn_query(query_vector, k, num_candidates, params["oversample"]) + knn_result = await es.search( + body=knn_body, + index=index, + request_cache=request_cache, + size=k, + ) + knn_hits = [hit["fields"]["questionId"][0] for hit in knn_result["hits"]["hits"]] + true_neighbors = knn_vector_store.get_neighbors_for_query(query_id, k)[:k] + current_recall = len(set(knn_hits).intersection(set(true_neighbors))) + recall_total += current_recall + exact_total += len(true_neighbors) + min_recall = min(min_recall, current_recall) + max_recall = max(max_recall, current_recall) + to_return = { + "avg_recall": recall_total / exact_total, + "min_recall": min_recall, + "max_recall": max_recall, + "k": k, + "num_candidates": num_candidates, + "oversample": params["oversample"], + } + logger.info(f"Recall results: {to_return}") + return to_return + + def __repr__(self, *args, **kwargs): + return "knn-recall" + + def register(registry): registry.register_param_source("knn-param-source", KnnParamSource) + registry.register_param_source("knn-recall-param-source", KnnRecallParamSource) + registry.register_runner("knn-recall", KnnRecallRunner(), async_runner=True)