Skip to content

Commit

Permalink
fix: some RAG retrieval bugs (#1577)
Browse files Browse the repository at this point in the history
Co-authored-by: Joel <[email protected]>
  • Loading branch information
zxhlyh and iamjoel authored Nov 21, 2023
1 parent d0456d0 commit 6768fd4
Show file tree
Hide file tree
Showing 15 changed files with 267 additions and 106 deletions.
6 changes: 5 additions & 1 deletion web/app/components/app/chat/citation/popup.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ const Popup: FC<PopupProps> = ({
data={source.index_node_hash.substring(0, 7)}
icon={<BezierCurve03 className='mr-1 w-3 h-3' />}
/>
<ProgressTooltip data={Number(source.score.toFixed(2))} />
{
source.score && (
<ProgressTooltip data={Number(source.score.toFixed(2))} />
)
}
</div>
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleValueChange = (type: string, value: string) => {
Expand All @@ -78,6 +79,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
Expand Down Expand Up @@ -270,7 +272,7 @@ const SettingsModal: FC<SettingsModalProps> = ({
)}

<div
className='absolute z-10 bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white '
className='absolute z-[5] bottom-0 w-full flex justify-end py-4 px-6 border-t bg-white '
style={{
borderColor: 'rgba(0, 0, 0, 0.05)',
}}
Expand Down
19 changes: 15 additions & 4 deletions web/app/components/datasets/common/check-rerank-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,29 @@ export const isReRankModelSelected = ({
rerankDefaultModel,
isRerankDefaultModelVaild,
retrievalConfig,
rerankModelList,
indexMethod,
}: {
rerankDefaultModel?: BackendModel
isRerankDefaultModelVaild: boolean
retrievalConfig: RetrievalConfig
rerankModelList: BackendModel[]
indexMethod?: string
}) => {
const rerankModel = (retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined) || (isRerankDefaultModelVaild ? rerankDefaultModel : undefined)
const rerankModelSelected = (() => {
if (retrievalConfig.reranking_model?.reranking_model_name)
return !!rerankModelList.find(({ model_name }) => model_name === retrievalConfig.reranking_model?.reranking_model_name)

if (isRerankDefaultModelVaild)
return !!rerankDefaultModel

return false
})()

if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& !rerankModel
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModelSelected
)
return false

Expand All @@ -35,7 +46,7 @@ export const ensureRerankModelSelected = ({
const rerankModel = retrievalConfig.reranking_model?.reranking_model_name ? retrievalConfig.reranking_model : undefined
if (
indexMethod === 'high_quality'
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.fullText)
&& (retrievalConfig.reranking_enable || retrievalConfig.search_method === RETRIEVE_METHOD.hybrid)
&& !rerankModel
) {
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,23 @@ type Props = {
}

const RetrievalMethodConfig: FC<Props> = ({
value,
value: passValue,
onChange,
}) => {
const { t } = useTranslation()
const { supportRetrievalMethods } = useProviderContext()
const { supportRetrievalMethods, rerankDefaultModel } = useProviderContext()
const value = (() => {
if (!passValue.reranking_model.reranking_model_name) {
return {
...passValue,
reranking_model: {
reranking_provider_name: rerankDefaultModel?.model_provider.provider_name || '',
reranking_model_name: rerankDefaultModel?.model_name || '',
},
}
}
return passValue
})()
return (
<div className='space-y-2'>
{supportRetrievalMethods.includes(RETRIEVE_METHOD.semantic) && (
Expand Down
5 changes: 5 additions & 0 deletions web/app/components/datasets/create/step-two/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ const StepTwo = ({
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()
const getCreationParams = () => {
let params
Expand All @@ -282,6 +283,7 @@ const StepTwo = ({
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
// eslint-disable-next-line @typescript-eslint/no-use-before-define
retrievalConfig,
indexMethod: indexMethod as string,
Expand Down Expand Up @@ -359,6 +361,9 @@ const StepTwo = ({
try {
let res
const params = getCreationParams()
if (!params)
return false

setIsCreating(true)
if (!datasetId) {
res = await createFirstDocument({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@ import type { FC } from 'react'
import React, { useRef, useState } from 'react'
import { useClickAway } from 'ahooks'
import { useTranslation } from 'react-i18next'
import Toast from '../../base/toast'
import { XClose } from '@/app/components/base/icons/src/vender/line/general'
import type { RetrievalConfig } from '@/types/app'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
import Button from '@/app/components/base/button'
import { useProviderContext } from '@/context/provider-context'
import { ensureRerankModelSelected, isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model'

type Props = {
indexMethod: string
Expand All @@ -33,6 +36,32 @@ const ModifyRetrievalModal: FC<Props> = ({
onHide()
}, ref)

const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleSave = () => {
if (
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
) {
Toast.notify({ type: 'error', message: t('appDebug.datasetConfig.rerankModelRequired') })
return
}
onSave(ensureRerankModelSelected({
rerankDefaultModel: rerankDefaultModel!,
retrievalConfig,
indexMethod,
}))
}

if (!isShow)
return null

Expand Down Expand Up @@ -87,7 +116,7 @@ const ModifyRetrievalModal: FC<Props> = ({
}}
>
<Button className='mr-2 flex-shrink-0' onClick={onHide}>{t('common.operation.cancel')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={() => onSave(retrievalConfig)} >{t('common.operation.save')}</Button>
<Button type='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
</div>
</div>
)
Expand Down
2 changes: 2 additions & 0 deletions web/app/components/datasets/settings/form/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ const Form = () => {
const {
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
} = useProviderContext()

const handleSave = async () => {
Expand All @@ -72,6 +73,7 @@ const Form = () => {
!isReRankModelSelected({
rerankDefaultModel,
isRerankDefaultModelVaild,
rerankModelList,
retrievalConfig,
indexMethod,
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ const config: ProviderConfig = {
'en': <CohereText className='w-[120px] h-6' />,
'zh-Hans': <CohereText className='w-[120px] h-6' />,
},
hit: {
'en': 'Rerank Model Supported',
'zh-Hans': '支持 Rerank 模型',
},
},
modal: {
key: ProviderEnum.cohere,
title: {
'en': 'cohere',
'zh-Hans': 'cohere',
'en': 'Rerank Model',
'zh-Hans': 'Rerank 模型',
},
icon: <Cohere className='w-6 h-6' />,
link: {
Expand Down
22 changes: 19 additions & 3 deletions web/app/components/header/account-setting/model-page/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import { ModelType } from '@/app/components/header/account-setting/model-page/de
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import I18n from '@/context/i18n'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'

const MODEL_CARD_LIST = [
config.openai,
Expand All @@ -42,6 +43,10 @@ const ModelPage = () => {
const { locale } = useContext(I18n)
const {
updateModelList,
textGenerationDefaultModel,
embeddingsDefaultModel,
speech2textDefaultModel,
rerankDefaultModel,
} = useProviderContext()
const { data: providers, mutate: mutateProviders } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
const [showModal, setShowModal] = useState(false)
Expand Down Expand Up @@ -196,11 +201,22 @@ const ModelPage = () => {
}
}

const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel

return (
<div className='relative pt-1 -mt-2'>
<div className='flex items-center justify-between mb-2 h-8'>
<div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div>
<SystemModel />
<div className={`flex items-center justify-between mb-2 h-8 ${defaultModelNotConfigured && 'px-3 bg-[#FFFAEB] rounded-lg border border-[#FEF0C7]'}`}>
{
defaultModelNotConfigured
? (
<div className='flex items-center text-xs font-medium text-gray-700'>
<AlertTriangle className='mr-1 w-3 h-3 text-[#F79009]' />
{t('common.modelProvider.notConfigured')}
</div>
)
: <div className='text-sm font-medium text-gray-800'>{t('common.modelProvider.models')}</div>
}
<SystemModel onUpdate={() => mutateProviders()} />
</div>
<div className='grid grid-cols-2 gap-4 mb-6'>
{
Expand Down
Loading

0 comments on commit 6768fd4

Please sign in to comment.