Feat/add retriever rerank #1560

Merged · 24 commits · Nov 17, 2023
57 changes: 46 additions & 11 deletions api/commands.py
@@ -8,6 +8,8 @@
 import uuid
 
 import click
+import qdrant_client
+from qdrant_client.http.models import TextIndexParams, TextIndexType, TokenizerType
 from tqdm import tqdm
 from flask import current_app, Flask
 from langchain.embeddings import OpenAIEmbeddings
@@ -484,6 +486,38 @@ def normalization_collections():
     click.echo(click.style('Congratulations! restore {} dataset indexes.'.format(len(normalization_count)), fg='green'))
 
 
+@click.command('add-qdrant-full-text-index', help='add qdrant full text index')
+def add_qdrant_full_text_index():
+    click.echo(click.style('Start add full text index.', fg='green'))
+    binds = db.session.query(DatasetCollectionBinding).all()
+    if binds and current_app.config['VECTOR_STORE'] == 'qdrant':
+        qdrant_url = current_app.config['QDRANT_URL']
+        qdrant_api_key = current_app.config['QDRANT_API_KEY']
+        client = qdrant_client.QdrantClient(
+            qdrant_url,
+            api_key=qdrant_api_key,  # For Qdrant Cloud, None for local instance
+        )
+        for bind in binds:
+            try:
+                text_index_params = TextIndexParams(
+                    type=TextIndexType.TEXT,
+                    tokenizer=TokenizerType.MULTILINGUAL,
+                    min_token_len=2,
+                    max_token_len=20,
+                    lowercase=True
+                )
+                client.create_payload_index(bind.collection_name, 'page_content',
+                                            field_schema=text_index_params)
+            except Exception as e:
+                click.echo(
+                    click.style('Create full text index error: {} {}'.format(e.__class__.__name__, str(e)),
+                                fg='red'))
+            click.echo(
+                click.style(
+                    'Congratulations! add collection {} full text index successful.'.format(bind.collection_name),
+                    fg='green'))
+
+
 def deal_dataset_vector(flask_app: Flask, dataset: Dataset, normalization_count: list):
     with flask_app.app_context():
         try:
@@ -647,24 +681,24 @@ def update_app_model_configs(batch_size):
 
             pbar.update(len(data_batch))
 
 
 @click.command('migrate_default_input_to_dataset_query_variable')
 @click.option("--batch-size", default=500, help="Number of records to migrate in each batch.")
 def migrate_default_input_to_dataset_query_variable(batch_size):
 
     click.secho("Starting...", fg='green')
 
     total_records = db.session.query(AppModelConfig) \
         .join(App, App.app_model_config_id == AppModelConfig.id) \
         .filter(App.mode == 'completion') \
         .filter(AppModelConfig.dataset_query_variable == None) \
         .count()
 
     if total_records == 0:
         click.secho("No data to migrate.", fg='green')
         return
 
     num_batches = (total_records + batch_size - 1) // batch_size
 
     with tqdm(total=total_records, desc="Migrating Data") as pbar:
         for i in range(num_batches):
             offset = i * batch_size
@@ -697,22 +731,22 @@ def migrate_default_input_to_dataset_query_variable(batch_size):
                             for form in user_input_form:
                                 paragraph = form.get('paragraph')
                                 if paragraph \
-                                    and paragraph.get('variable') == 'query':
-                                    data.dataset_query_variable = 'query'
-                                    break
+                                        and paragraph.get('variable') == 'query':
+                                    data.dataset_query_variable = 'query'
+                                    break
 
                                 if paragraph \
-                                    and paragraph.get('variable') == 'default_input':
-                                    data.dataset_query_variable = 'default_input'
-                                    break
+                                        and paragraph.get('variable') == 'default_input':
+                                    data.dataset_query_variable = 'default_input'
+                                    break
 
                     db.session.commit()
 
                 except Exception as e:
                     click.secho(f"Error while migrating data: {e}, app_id: {data.app_id}, app_model_config_id: {data.id}",
                                 fg='red')
                     continue
 
                 click.secho(f"Successfully migrated batch {i + 1}/{num_batches}.", fg='green')
 
             pbar.update(len(data_batch))
@@ -731,3 +765,4 @@ def register_commands(app):
     app.cli.add_command(update_app_model_configs)
     app.cli.add_command(normalization_collections)
     app.cli.add_command(migrate_default_input_to_dataset_query_variable)
+    app.cli.add_command(add_qdrant_full_text_index)
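
Once registered above, the new command should be invokable as `flask add-qdrant-full-text-index`. For context, the payload index it creates is what makes Qdrant's full-text MatchText filter usable on page_content. A minimal sketch of querying such an index afterwards — the URL and the collection name 'demo_collection' are placeholder assumptions, not taken from this PR:

    # Sketch: full-text filtering against a collection that carries the
    # 'page_content' text index created by the new CLI command.
    import qdrant_client
    from qdrant_client.http import models

    client = qdrant_client.QdrantClient('http://localhost:6333')
    points, _next_page = client.scroll(
        collection_name='demo_collection',   # hypothetical collection name
        scroll_filter=models.Filter(must=[
            models.FieldCondition(key='page_content',
                                  match=models.MatchText(text='rerank')),
        ]),
        limit=5,
    )
    for point in points:
        print(point.id, point.payload.get('page_content'))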
48 changes: 48 additions & 0 deletions api/controllers/console/datasets/datasets.py
@@ -170,6 +170,7 @@ def patch(self, dataset_id):
                                 help='Invalid indexing technique.')
         parser.add_argument('permission', type=str, location='json', choices=(
             'only_me', 'all_team_members'), help='Invalid permission.')
+        parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.')
         args = parser.parse_args()
 
         # The role of the current user in the ta table must be admin or owner
@@ -401,6 +402,7 @@ def post(self):
 
 class DatasetApiDeleteApi(Resource):
     resource_type = 'dataset'
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -436,6 +438,50 @@ def get(self):
         }
 
 
+class DatasetRetrievalSettingApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        vector_type = current_app.config['VECTOR_STORE']
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
+class DatasetRetrievalSettingMockApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, vector_type):
+
+        if vector_type == 'milvus':
+            return {
+                'retrieval_method': [
+                    'semantic_search'
+                ]
+            }
+        elif vector_type == 'qdrant' or vector_type == 'weaviate':
+            return {
+                'retrieval_method': [
+                    'semantic_search', 'full_text_search', 'hybrid_search'
+                ]
+            }
+        else:
+            raise ValueError("Unsupported vector db type.")
+
+
 api.add_resource(DatasetListApi, '/datasets')
 api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
 api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
@@ -445,3 +491,5 @@ def get(self):
 api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
 api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
 api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
+api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
+api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
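
The two new endpoints let the console ask which retrieval methods the configured vector store supports, and, via the mock variant, any store type by name. A hedged sketch of calling them — the '/console/api' prefix and the bearer-token header are deployment assumptions, not something this diff pins down:

    import requests

    BASE = 'http://localhost:5001/console/api'             # assumed API prefix
    HEADERS = {'Authorization': 'Bearer <console-token>'}  # hypothetical auth

    # Methods supported by the currently configured VECTOR_STORE:
    print(requests.get(f'{BASE}/datasets/retrieval-setting', headers=HEADERS).json())
    # e.g. {'retrieval_method': ['semantic_search', 'full_text_search', 'hybrid_search']}

    # Mock variant for an arbitrary store type:
    print(requests.get(f'{BASE}/datasets/retrieval-setting/milvus', headers=HEADERS).json())
    # e.g. {'retrieval_method': ['semantic_search']}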
4 changes: 4 additions & 0 deletions api/controllers/console/datasets/datasets_document.py
@@ -221,6 +221,8 @@ def post(self, dataset_id):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
 
         if not dataset.indexing_technique and not args['indexing_technique']:
@@ -263,6 +265,8 @@ def post(self):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         if args['indexing_technique'] == 'high_quality':
             try:
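
Both document routes now accept a retrieval_model argument, but the diff only types it as dict. One plausible shape, inferred from the retrieval/rerank feature this PR adds — every key below is an assumption, not a schema read from the code:

    # Illustrative only; key names are assumptions, not taken from this diff.
    retrieval_model = {
        'search_method': 'hybrid_search',  # semantic_search | full_text_search | hybrid_search
        'reranking_enable': True,
        'reranking_model': {
            'reranking_provider_name': 'cohere',                 # hypothetical provider
            'reranking_model_name': 'rerank-multilingual-v2.0',  # hypothetical model
        },
        'top_k': 4,
        'score_threshold_enable': False,
    }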
11 changes: 5 additions & 6 deletions api/controllers/console/datasets/hit_testing.py
@@ -42,19 +42,18 @@ def post(self, dataset_id):
 
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, location='json')
         args = parser.parse_args()
 
-        query = args['query']
-
-        if not query or len(query) > 250:
-            raise ValueError('Query is required and cannot exceed 250 characters')
+        HitTestingService.hit_testing_args_check(args)
 
         try:
             response = HitTestingService.retrieve(
                 dataset=dataset,
-                query=query,
+                query=args['query'],
                 account=current_user,
-                limit=10,
+                retrieval_model=args['retrieval_model'],
+                limit=10
             )
 
             return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)}
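
Hit testing now takes the same optional retrieval_model and delegates validation to HitTestingService.hit_testing_args_check. A sketch of exercising it over HTTP — the route is assumed from the resource name, and the prefix/token are the same deployment assumptions as above:

    import requests

    # '<dataset_id>' stays a placeholder; prefix and token are assumptions.
    resp = requests.post(
        'http://localhost:5001/console/api/datasets/<dataset_id>/hit-testing',
        headers={'Authorization': 'Bearer <console-token>'},
        json={
            'query': 'what does reranking change?',
            'retrieval_model': {'search_method': 'semantic_search',
                                'reranking_enable': False},  # assumed keys, as above
        },
    )
    print(resp.json()['records'])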
21 changes: 10 additions & 11 deletions api/controllers/console/workspace/models.py
@@ -19,7 +19,7 @@ class DefaultModelApi(Resource):
     def get(self):
         parser = reqparse.RequestParser()
         parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='args')
+                            choices=['text-generation', 'embeddings', 'speech2text', 'reranking'], location='args')
         args = parser.parse_args()
 
         tenant_id = current_user.current_tenant_id
@@ -71,19 +71,18 @@ def get(self):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('model_name', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('model_type', type=str, required=True, nullable=False,
-                            choices=['text-generation', 'embeddings', 'speech2text'], location='json')
-        parser.add_argument('provider_name', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json')
         args = parser.parse_args()
 
         provider_service = ProviderService()
-        provider_service.update_default_model_of_model_type(
-            tenant_id=current_user.current_tenant_id,
-            model_type=args['model_type'],
-            provider_name=args['provider_name'],
-            model_name=args['model_name']
-        )
+        model_settings = args['model_settings']
+        for model_setting in model_settings:
+            provider_service.update_default_model_of_model_type(
+                tenant_id=current_user.current_tenant_id,
+                model_type=model_setting['model_type'],
+                provider_name=model_setting['provider_name'],
+                model_name=model_setting['model_name']
+            )
 
         return {'result': 'success'}
 
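
The default-model POST now takes a model_settings list and loops over it, so one request can set defaults for every model type, including the new reranking type. The three keys per entry are exactly what the loop reads; the route and auth below are assumptions, since this hunk does not show them:

    import requests

    requests.post(
        'http://localhost:5001/console/api/workspaces/current/default-model',  # assumed route
        headers={'Authorization': 'Bearer <console-token>'},                    # hypothetical auth
        json={'model_settings': [
            {'model_type': 'text-generation', 'provider_name': 'openai',
             'model_name': 'gpt-3.5-turbo'},
            {'model_type': 'embeddings', 'provider_name': 'openai',
             'model_name': 'text-embedding-ada-002'},
            {'model_type': 'reranking', 'provider_name': 'cohere',
             'model_name': 'rerank-multilingual-v2.0'},
        ]},
    )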
4 changes: 4 additions & 0 deletions api/controllers/service_api/dataset/document.py
@@ -36,6 +36,8 @@ def post(self, tenant_id, dataset_id):
                             location='json')
         parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
@@ -95,6 +97,8 @@ def post(self, tenant_id, dataset_id, document_id):
         parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json')
         parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                             location='json')
+        parser.add_argument('retrieval_model', type=dict, required=False, nullable=False,
+                            location='json')
         args = parser.parse_args()
         dataset_id = str(dataset_id)
         tenant_id = str(tenant_id)
2 changes: 0 additions & 2 deletions api/core/agent/agent/multi_dataset_router_agent.py
@@ -14,7 +14,6 @@
 from core.model_providers.models.entity.message import to_prompt_messages
 from core.model_providers.models.llm.base import BaseLLM
 from core.third_party.langchain.llms.fake import FakeLLM
-from core.tool.dataset_retriever_tool import DatasetRetrieverTool
 
 
 class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@@ -60,7 +59,6 @@ def plan(
             return AgentFinish(return_values={"output": ''}, log='')
         elif len(self.tools) == 1:
             tool = next(iter(self.tools))
-            tool = cast(DatasetRetrieverTool, tool)
             rst = tool.run(tool_input={'query': kwargs['input']})
             # output = ''
             # rst_json = json.loads(rst)