Skip to content

Commit 084d6d1

Browse files
authored
Feat/add retriever rerank (langgenius#1560)
Co-authored-by: jyong <[email protected]>
1 parent fb2da40 commit 084d6d1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+1903
-168
lines changed

Diff for: api/commands.py

+46-11
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import uuid
99

1010
import click
11+
import qdrant_client
12+
from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
1113
from tqdm import tqdm
1214
from flask import current_app, Flask
1315
from langchain.embeddings import OpenAIEmbeddings
@@ -484,6 +486,38 @@ def normalization_collections():
484486
click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
485487

486488

489+
@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
490+
def add_qdrant_full_text_index():
491+
click.echo(click.style('Start add full text index.', fg='green'))
492+
binds = db.session.query(DatasetCollectionBinding).all()
493+
if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
494+
qdrant_url = current_app.config['QDRANT_URL']
495+
qdrant_api_key = current_app.config['QDRANT_API_KEY']
496+
client = qdrant_client.QdrantClient(
497+
qdrant_url,
498+
api_key=qdrant_api_key, # For Qdrant Cloud, None for local instance
499+
)
500+
for bind in binds:
501+
try:
502+
text_index_params = TextIndexParams(
503+
type=TextIndexType.TEXT,
504+
tokenizer=TokenizerType.MULTILINGUAL,
505+
min_token_len=2,
506+
max_token_len=20,
507+
lowercase=True
508+
)
509+
client.create_payload_index(bind.collection_name, 'page_content',
510+
field_schema=text_index_params)
511+
except Exception as e:
512+
click.echo(
513+
click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
514+
fg='red'))
515+
click.echo(
516+
click.style(
517+
'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
518+
fg='green'))
519+
520+
487521
def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
488522
with flask_app.app_context():
489523
try:
@@ -647,24 +681,24 @@ def update_app_model_configs(batch_size):
647681

648682
pbar.update(len(data_batch))
649683

684+
650685
@click.command('migrate_default_input_to_dataset_query_variable')
651686
@click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
652687
def migrate_default_input_to_dataset_query_variable(batch_size):
653-
654688
click.secho("Starting...", fg='green')
655689

656690
total_records = db.session.query(AppModelConfig) \
657691
.join(App, App.app_model_config_id == AppModelConfig.id) \
658692
.filter(App.mode == 'completion') \
659693
.filter(AppModelConfig.dataset_query_variable == None) \
660694
.count()
661-
695+
662696
if total_records == 0:
663697
click.secho("No data to migrate.", fg='green')
664698
return
665699

666700
num_batches = (total_records + batch_size - 1) // batch_size
667-
701+
668702
with tqdm(total=total_records, desc="Migrating Data") as pbar:
669703
for i in range(num_batches):
670704
offset = i * batch_size
@@ -697,22 +731,22 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
697731
for form in user_input_form:
698732
paragraph = form.get('paragraph')
699733
if paragraph \
700-
and paragraph.get('variable') == 'query':
701-
data.dataset_query_variable = 'query'
702-
break
703-
734+
and paragraph.get('variable') == 'query':
735+
data.dataset_query_variable = 'query'
736+
break
737+
704738
if paragraph \
705-
and paragraph.get('variable') == 'default_input':
706-
data.dataset_query_variable = 'default_input'
707-
break
739+
and paragraph.get('variable') == 'default_input':
740+
data.dataset_query_variable = 'default_input'
741+
break
708742

709743
db.session.commit()
710744

711745
except Exception as e:
712746
click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
713747
fg='red')
714748
continue
715-
749+
716750
click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
717751

718752
pbar.update(len(data_batch))
@@ -731,3 +765,4 @@ def register_commands(app):
731765
app.cli.add_command(update_app_model_configs)
732766
app.cli.add_command(normalization_collections)
733767
app.cli.add_command(migrate_default_input_to_dataset_query_variable)
768+
app.cli.add_command(add_qdrant_full_text_index)

Diff for: api/controllers/console/datasets/datasets.py

+48
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ def patch(self, dataset_id):
170170
help='Invalid indexing technique.')
171171
parser.add_argument('permission', type=str, location='json', choices=(
172172
'only_me', 'all_team_members'), help='Invalid permission.')
173+
parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
173174
args = parser.parse_args()
174175

175176
# The role of the current user in the ta table must be admin or owner
@@ -401,6 +402,7 @@ def post(self):
401402

402403
class DatasetApiDeleteApi(Resource):
403404
resource_type = 'dataset'
405+
404406
@setup_required
405407
@login_required
406408
@account_initialization_required
@@ -436,6 +438,50 @@ def get(self):
436438
}
437439

438440

441+
class DatasetRetrievalSettingApi(Resource):
442+
@setup_required
443+
@login_required
444+
@account_initialization_required
445+
def get(self):
446+
vector_type = current_app.config['VECTOR_STORE']
447+
if vector_type == 'milvus':
448+
return {
449+
'retrieval_method': [
450+
'semantic_search'
451+
]
452+
}
453+
elif vector_type == 'qdrant' or vector_type == 'weaviate':
454+
return {
455+
'retrieval_method': [
456+
'semantic_search', 'full_text_search', 'hybrid_search'
457+
]
458+
}
459+
else:
460+
raise ValueError("Unsupported vector db type.")
461+
462+
463+
class DatasetRetrievalSettingMockApi(Resource):
464+
@setup_required
465+
@login_required
466+
@account_initialization_required
467+
def get(self, vector_type):
468+
469+
if vector_type == 'milvus':
470+
return {
471+
'retrieval_method': [
472+
'semantic_search'
473+
]
474+
}
475+
elif vector_type == 'qdrant' or vector_type == 'weaviate':
476+
return {
477+
'retrieval_method': [
478+
'semantic_search', 'full_text_search', 'hybrid_search'
479+
]
480+
}
481+
else:
482+
raise ValueError("Unsupported vector db type.")
483+
484+
439485
api.add_resource(DatasetListApi, '/datasets')
440486
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
441487
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
@@ -445,3 +491,5 @@ def get(self):
445491
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
446492
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
447493
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
494+
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
495+
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')

Diff for: api/controllers/console/datasets/datasets_document.py

+4
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def post(self, dataset_id):
221221
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
222222
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
223223
location='json')
224+
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
225+
location='json')
224226
args = parser.parse_args()
225227

226228
if not dataset.indexing_technique and not args['indexing_technique']:
@@ -263,6 +265,8 @@ def post(self):
263265
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
264266
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
265267
location='json')
268+
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
269+
location='json')
266270
args = parser.parse_args()
267271
if args['indexing_technique'] == 'high_quality':
268272
try:

Diff for: api/controllers/console/datasets/hit_testing.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,18 @@ def post(self, dataset_id):
4242

4343
parser = reqparse.RequestParser()
4444
parser.add_argument('query', type=str, location='json')
45+
parser.add_argument('retrieval_model', type=dict, required=False, location='json')
4546
args = parser.parse_args()
4647

47-
query = args['query']
48-
49-
if not query or len(query) > 250:
50-
raise ValueError('Query is required and cannot exceed 250 characters')
48+
HitTestingService.hit_testing_args_check(args)
5149

5250
try:
5351
response = HitTestingService.retrieve(
5452
dataset=dataset,
55-
query=query,
53+
query=args['query'],
5654
account=current_user,
57-
limit=10,
55+
retrieval_model=args['retrieval_model'],
56+
limit=10
5857
)
5958

6059
return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}

Diff for: api/controllers/console/workspace/models.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
1919
def get(self):
2020
parser = reqparse.RequestParser()
2121
parser.add_argument('model_type', type=str, required=True, nullable=False,
22-
choices=['text-generation', 'embeddings', 'speech2text'], location='args')
22+
choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
2323
args = parser.parse_args()
2424

2525
tenant_id = current_user.current_tenant_id
@@ -71,19 +71,18 @@ def get(self):
7171
@account_initialization_required
7272
def post(self):
7373
parser = reqparse.RequestParser()
74-
parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
75-
parser.add_argument('model_type', type=str, required=True, nullable=False,
76-
choices=['text-generation', 'embeddings', 'speech2text'], location='json')
77-
parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
74+
parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
7875
args = parser.parse_args()
7976

8077
provider_service = ProviderService()
81-
provider_service.update_default_model_of_model_type(
82-
tenant_id=current_user.current_tenant_id,
83-
model_type=args['model_type'],
84-
provider_name=args['provider_name'],
85-
model_name=args['model_name']
86-
)
78+
model_settings = args['model_settings']
79+
for model_setting in model_settings:
80+
provider_service.update_default_model_of_model_type(
81+
tenant_id=current_user.current_tenant_id,
82+
model_type=model_setting['model_type'],
83+
provider_name=model_setting['provider_name'],
84+
model_name=model_setting['model_name']
85+
)
8786

8887
return {'result': 'success'}
8988

Diff for: api/controllers/service_api/dataset/document.py

+4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def post(self, tenant_id, dataset_id):
3636
location='json')
3737
parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
3838
location='json')
39+
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
40+
location='json')
3941
args = parser.parse_args()
4042
dataset_id = str(dataset_id)
4143
tenant_id = str(tenant_id)
@@ -95,6 +97,8 @@ def post(self, tenant_id, dataset_id, document_id):
9597
parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
9698
parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
9799
location='json')
100+
parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
101+
location='json')
98102
args = parser.parse_args()
99103
dataset_id = str(dataset_id)
100104
tenant_id = str(tenant_id)

Diff for: api/core/agent/agent/multi_dataset_router_agent.py

-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from core.model_providers.models.entity.message import to_prompt_messages
1515
from core.model_providers.models.llm.base import BaseLLM
1616
from core.third_party.langchain.llms.fake import FakeLLM
17-
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
1817

1918

2019
class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -60,7 +59,6 @@ def plan(
6059
return AgentFinish(return_values={"output": ''}, log='')
6160
elif len(self.tools) == 1:
6261
tool = next(iter(self.tools))
63-
tool = cast(DatasetRetrieverTool, tool)
6462
rst = tool.run(tool_input={'query': kwargs['input']})
6563
# output = ''
6664
# rst_json = json.loads(rst)

0 commit comments

Comments
 (0)