diff --git a/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js b/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js new file mode 100644 index 0000000000..5c06d6ddf5 --- /dev/null +++ b/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js @@ -0,0 +1,24 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table + .jsonb('adjustment') + .nullable() + .comment( + 'Adjustment data for thread response, including type and payload', + ); + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table.dropColumn('adjustment'); + }); +}; diff --git a/wren-ui/src/apollo/client/graphql/__types__.ts b/wren-ui/src/apollo/client/graphql/__types__.ts index 36b7959de6..18454a756e 100644 --- a/wren-ui/src/apollo/client/graphql/__types__.ts +++ b/wren-ui/src/apollo/client/graphql/__types__.ts @@ -23,6 +23,21 @@ export type AdjustThreadResponseChartInput = { yAxis?: InputMaybe; }; +export type AdjustThreadResponseInput = { + sql?: InputMaybe; + sqlGenerationReasoning?: InputMaybe; + tables?: InputMaybe>; +}; + +export type AdjustmentTask = { + __typename?: 'AdjustmentTask'; + error?: Maybe; + queryId?: Maybe; + sql?: Maybe; + status?: Maybe; + traceId?: Maybe; +}; + export type AskingTask = { __typename?: 'AskingTask'; candidates: Array; @@ -547,7 +562,9 @@ export type ModelWhereInput = { export type Mutation = { __typename?: 'Mutation'; + adjustThreadResponse: ThreadResponse; adjustThreadResponseChart: ThreadResponse; + cancelAdjustmentTask: Scalars['Boolean']; cancelAskingTask: Scalars['Boolean']; createAskingTask: Task; createCalculatedField: Scalars['JSON']; @@ -581,6 +598,7 @@ export type Mutation = { previewModelData: Scalars['JSON']; previewSql: Scalars['JSON']; previewViewData: Scalars['JSON']; + rerunAdjustmentTask: Scalars['Boolean']; rerunAskingTask: Task; resetCurrentProject: Scalars['Boolean']; resolveSchemaChange: Scalars['Boolean']; @@ -607,12 +625,23 @@ export type Mutation = { }; +export type MutationAdjustThreadResponseArgs = { + data: AdjustThreadResponseInput; + responseId: Scalars['Int']; +}; + + export type MutationAdjustThreadResponseChartArgs = { data: AdjustThreadResponseChartInput; responseId: Scalars['Int']; }; +export type MutationCancelAdjustmentTaskArgs = { + taskId: Scalars['String']; +}; + + export type MutationCancelAskingTaskArgs = { taskId: Scalars['String']; }; @@ -774,6 +803,11 @@ export type MutationPreviewViewDataArgs = { }; +export type MutationRerunAdjustmentTaskArgs = { + responseId: Scalars['Int']; +}; + + export type MutationRerunAskingTaskArgs = { responseId: Scalars['Int']; }; @@ -933,7 +967,7 @@ export type PreviewItemSqlInput = { export type PreviewSqlDataInput = { dryRun?: InputMaybe; limit?: InputMaybe; - projectId?: InputMaybe; + projectId?: InputMaybe; sql: Scalars['String']; }; @@ -957,6 +991,7 @@ export enum ProjectLanguage { export type Query = { __typename?: 'Query'; + adjustmentTask?: Maybe; askingTask?: Maybe; autoGenerateRelation: Array; dashboardItems: Array; @@ -985,6 +1020,11 @@ export type Query = { }; +export type QueryAdjustmentTaskArgs = { + taskId: Scalars['String']; +}; + + export type QueryAskingTaskArgs = { taskId: Scalars['String']; }; @@ -1200,6 +1240,8 @@ export type Thread = { export type ThreadResponse = { __typename?: 'ThreadResponse'; + adjustment?: Maybe; + adjustmentTask?: Maybe; answerDetail?: Maybe; askingTask?: Maybe; breakdownDetail?: Maybe; @@ -1211,6 +1253,17 @@ export type ThreadResponse = { view?: Maybe; }; +export type ThreadResponseAdjustment = { + __typename?: 'ThreadResponseAdjustment'; + payload?: Maybe; + type: ThreadResponseAdjustmentType; +}; + +export enum ThreadResponseAdjustmentType { + APPLY_SQL = 'APPLY_SQL', + REASONING = 'REASONING' +} + export type ThreadResponseAnswerDetail = { __typename?: 'ThreadResponseAnswerDetail'; content?: Maybe; diff --git a/wren-ui/src/apollo/client/graphql/home.generated.ts b/wren-ui/src/apollo/client/graphql/home.generated.ts index d9ece92c47..95a9093d6f 100644 --- a/wren-ui/src/apollo/client/graphql/home.generated.ts +++ b/wren-ui/src/apollo/client/graphql/home.generated.ts @@ -13,7 +13,7 @@ export type CommonChartDetailFragment = { __typename?: 'ThreadResponseChartDetai export type CommonAskingTaskFragment = { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; -export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }; +export type CommonResponseFragment = { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }; export type CommonRecommendedQuestionsTaskFragment = { __typename?: 'RecommendedQuestionsTask', status: Types.RecommendedQuestionsTaskStatus, questions: Array<{ __typename?: 'ResultQuestion', question: string, category: string, sql: string }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null }; @@ -39,14 +39,14 @@ export type ThreadQueryVariables = Types.Exact<{ }>; -export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; +export type ThreadQuery = { __typename?: 'Query', thread: { __typename?: 'DetailedThread', id: number, responses: Array<{ __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }> } }; export type ThreadResponseQueryVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type ThreadResponseQuery = { __typename?: 'Query', threadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type CreateAskingTaskMutationVariables = Types.Exact<{ data: Types.AskingTaskInput; @@ -82,7 +82,7 @@ export type CreateThreadResponseMutationVariables = Types.Exact<{ }>; -export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type CreateThreadResponseMutation = { __typename?: 'Mutation', createThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type UpdateThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -98,7 +98,15 @@ export type UpdateThreadResponseMutationVariables = Types.Exact<{ }>; -export type UpdateThreadResponseMutation = { __typename?: 'Mutation', updateThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type UpdateThreadResponseMutation = { __typename?: 'Mutation', updateThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; + +export type AdjustThreadResponseMutationVariables = Types.Exact<{ + responseId: Types.Scalars['Int']; + data: Types.AdjustThreadResponseInput; +}>; + + +export type AdjustThreadResponseMutation = { __typename?: 'Mutation', adjustThreadResponse: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; export type DeleteThreadMutationVariables = Types.Exact<{ where: Types.ThreadUniqueWhereInput; @@ -166,34 +174,48 @@ export type GenerateThreadRecommendationQuestionsMutationVariables = Types.Exact export type GenerateThreadRecommendationQuestionsMutation = { __typename?: 'Mutation', generateThreadRecommendationQuestions: boolean }; -export type GenerateThreadResponseBreakdownMutationVariables = Types.Exact<{ +export type GenerateThreadResponseAnswerMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type GenerateThreadResponseBreakdownMutation = { __typename?: 'Mutation', generateThreadResponseBreakdown: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type GenerateThreadResponseAnswerMutation = { __typename?: 'Mutation', generateThreadResponseAnswer: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; -export type GenerateThreadResponseAnswerMutationVariables = Types.Exact<{ +export type GenerateThreadResponseChartMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; }>; -export type GenerateThreadResponseAnswerMutation = { __typename?: 'Mutation', generateThreadResponseAnswer: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type GenerateThreadResponseChartMutation = { __typename?: 'Mutation', generateThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; -export type GenerateThreadResponseChartMutationVariables = Types.Exact<{ +export type AdjustThreadResponseChartMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; + data: Types.AdjustThreadResponseChartInput; }>; -export type GenerateThreadResponseChartMutation = { __typename?: 'Mutation', generateThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type AdjustThreadResponseChartMutation = { __typename?: 'Mutation', adjustThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, adjustment?: { __typename?: 'ThreadResponseAdjustment', type: Types.ThreadResponseAdjustmentType, payload?: any | null } | null, adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; -export type AdjustThreadResponseChartMutationVariables = Types.Exact<{ +export type AdjustmentTaskQueryVariables = Types.Exact<{ + taskId: Types.Scalars['String']; +}>; + + +export type AdjustmentTaskQuery = { __typename?: 'Query', adjustmentTask?: { __typename?: 'AdjustmentTask', queryId?: string | null, status?: Types.AskingTaskStatus | null, sql?: string | null, traceId?: string | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null }; + +export type CancelAdjustmentTaskMutationVariables = Types.Exact<{ + taskId: Types.Scalars['String']; +}>; + + +export type CancelAdjustmentTaskMutation = { __typename?: 'Mutation', cancelAdjustmentTask: boolean }; + +export type RerunAdjustmentTaskMutationVariables = Types.Exact<{ responseId: Types.Scalars['Int']; - data: Types.AdjustThreadResponseChartInput; }>; -export type AdjustThreadResponseChartMutation = { __typename?: 'Mutation', adjustThreadResponseChart: { __typename?: 'ThreadResponse', id: number, threadId: number, question: string, sql?: string | null, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, breakdownDetail?: { __typename?: 'ThreadResponseBreakdownDetail', queryId?: string | null, status: Types.AskingTaskStatus, description?: string | null, steps?: Array<{ __typename?: 'DetailStep', summary: string, sql: string, cteName?: string | null }> | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, answerDetail?: { __typename?: 'ThreadResponseAnswerDetail', queryId?: string | null, status?: Types.ThreadResponseAnswerStatus | null, content?: string | null, numRowsUsedInLLM?: number | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, chartDetail?: { __typename?: 'ThreadResponseChartDetail', queryId?: string | null, status: Types.ChartTaskStatus, description?: string | null, chartType?: Types.ChartType | null, chartSchema?: any | null, adjustment?: boolean | null, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null, askingTask?: { __typename?: 'AskingTask', status: Types.AskingTaskStatus, type?: Types.AskingTaskType | null, rephrasedQuestion?: string | null, intentReasoning?: string | null, sqlGenerationReasoning?: string | null, retrievedTables?: Array | null, invalidSql?: string | null, traceId?: string | null, queryId?: string | null, candidates: Array<{ __typename?: 'ResultCandidate', sql: string, type: Types.ResultCandidateType, view?: { __typename?: 'ViewInfo', id: number, name: string, statement: string, displayName: string } | null, sqlPair?: { __typename?: 'SqlPair', id: number, question: string, sql: string, projectId: number } | null }>, error?: { __typename?: 'Error', code?: string | null, shortMessage?: string | null, message?: string | null, stacktrace?: Array | null } | null } | null } }; +export type RerunAdjustmentTaskMutation = { __typename?: 'Mutation', rerunAdjustmentTask: boolean }; export const CommonErrorFragmentDoc = gql` fragment CommonError on Error { @@ -298,11 +320,25 @@ export const CommonResponseFragmentDoc = gql` askingTask { ...CommonAskingTask } + adjustment { + type + payload + } + adjustmentTask { + queryId + status + error { + ...CommonError + } + sql + traceId + } } ${CommonBreakdownDetailFragmentDoc} ${CommonAnswerDetailFragmentDoc} ${CommonChartDetailFragmentDoc} -${CommonAskingTaskFragmentDoc}`; +${CommonAskingTaskFragmentDoc} +${CommonErrorFragmentDoc}`; export const CommonRecommendedQuestionsTaskFragmentDoc = gql` fragment CommonRecommendedQuestionsTask on RecommendedQuestionsTask { status @@ -729,6 +765,40 @@ export function useUpdateThreadResponseMutation(baseOptions?: Apollo.MutationHoo export type UpdateThreadResponseMutationHookResult = ReturnType; export type UpdateThreadResponseMutationResult = Apollo.MutationResult; export type UpdateThreadResponseMutationOptions = Apollo.BaseMutationOptions; +export const AdjustThreadResponseDocument = gql` + mutation AdjustThreadResponse($responseId: Int!, $data: AdjustThreadResponseInput!) { + adjustThreadResponse(responseId: $responseId, data: $data) { + ...CommonResponse + } +} + ${CommonResponseFragmentDoc}`; +export type AdjustThreadResponseMutationFn = Apollo.MutationFunction; + +/** + * __useAdjustThreadResponseMutation__ + * + * To run a mutation, you first call `useAdjustThreadResponseMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useAdjustThreadResponseMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [adjustThreadResponseMutation, { data, loading, error }] = useAdjustThreadResponseMutation({ + * variables: { + * responseId: // value for 'responseId' + * data: // value for 'data' + * }, + * }); + */ +export function useAdjustThreadResponseMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(AdjustThreadResponseDocument, options); + } +export type AdjustThreadResponseMutationHookResult = ReturnType; +export type AdjustThreadResponseMutationResult = Apollo.MutationResult; +export type AdjustThreadResponseMutationOptions = Apollo.BaseMutationOptions; export const DeleteThreadDocument = gql` mutation DeleteThread($where: ThreadUniqueWhereInput!) { deleteThread(where: $where) @@ -1053,39 +1123,6 @@ export function useGenerateThreadRecommendationQuestionsMutation(baseOptions?: A export type GenerateThreadRecommendationQuestionsMutationHookResult = ReturnType; export type GenerateThreadRecommendationQuestionsMutationResult = Apollo.MutationResult; export type GenerateThreadRecommendationQuestionsMutationOptions = Apollo.BaseMutationOptions; -export const GenerateThreadResponseBreakdownDocument = gql` - mutation GenerateThreadResponseBreakdown($responseId: Int!) { - generateThreadResponseBreakdown(responseId: $responseId) { - ...CommonResponse - } -} - ${CommonResponseFragmentDoc}`; -export type GenerateThreadResponseBreakdownMutationFn = Apollo.MutationFunction; - -/** - * __useGenerateThreadResponseBreakdownMutation__ - * - * To run a mutation, you first call `useGenerateThreadResponseBreakdownMutation` within a React component and pass it any options that fit your needs. - * When your component renders, `useGenerateThreadResponseBreakdownMutation` returns a tuple that includes: - * - A mutate function that you can call at any time to execute the mutation - * - An object with fields that represent the current status of the mutation's execution - * - * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; - * - * @example - * const [generateThreadResponseBreakdownMutation, { data, loading, error }] = useGenerateThreadResponseBreakdownMutation({ - * variables: { - * responseId: // value for 'responseId' - * }, - * }); - */ -export function useGenerateThreadResponseBreakdownMutation(baseOptions?: Apollo.MutationHookOptions) { - const options = {...defaultOptions, ...baseOptions} - return Apollo.useMutation(GenerateThreadResponseBreakdownDocument, options); - } -export type GenerateThreadResponseBreakdownMutationHookResult = ReturnType; -export type GenerateThreadResponseBreakdownMutationResult = Apollo.MutationResult; -export type GenerateThreadResponseBreakdownMutationOptions = Apollo.BaseMutationOptions; export const GenerateThreadResponseAnswerDocument = gql` mutation GenerateThreadResponseAnswer($responseId: Int!) { generateThreadResponseAnswer(responseId: $responseId) { @@ -1185,4 +1222,110 @@ export function useAdjustThreadResponseChartMutation(baseOptions?: Apollo.Mutati } export type AdjustThreadResponseChartMutationHookResult = ReturnType; export type AdjustThreadResponseChartMutationResult = Apollo.MutationResult; -export type AdjustThreadResponseChartMutationOptions = Apollo.BaseMutationOptions; \ No newline at end of file +export type AdjustThreadResponseChartMutationOptions = Apollo.BaseMutationOptions; +export const AdjustmentTaskDocument = gql` + query AdjustmentTask($taskId: String!) { + adjustmentTask(taskId: $taskId) { + queryId + status + error { + code + shortMessage + message + stacktrace + } + sql + traceId + } +} + `; + +/** + * __useAdjustmentTaskQuery__ + * + * To run a query within a React component, call `useAdjustmentTaskQuery` and pass it any options that fit your needs. + * When your component renders, `useAdjustmentTaskQuery` returns an object from Apollo Client that contains loading, error, and data properties + * you can use to render your UI. + * + * @param baseOptions options that will be passed into the query, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options; + * + * @example + * const { data, loading, error } = useAdjustmentTaskQuery({ + * variables: { + * taskId: // value for 'taskId' + * }, + * }); + */ +export function useAdjustmentTaskQuery(baseOptions: Apollo.QueryHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useQuery(AdjustmentTaskDocument, options); + } +export function useAdjustmentTaskLazyQuery(baseOptions?: Apollo.LazyQueryHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useLazyQuery(AdjustmentTaskDocument, options); + } +export type AdjustmentTaskQueryHookResult = ReturnType; +export type AdjustmentTaskLazyQueryHookResult = ReturnType; +export type AdjustmentTaskQueryResult = Apollo.QueryResult; +export const CancelAdjustmentTaskDocument = gql` + mutation CancelAdjustmentTask($taskId: String!) { + cancelAdjustmentTask(taskId: $taskId) +} + `; +export type CancelAdjustmentTaskMutationFn = Apollo.MutationFunction; + +/** + * __useCancelAdjustmentTaskMutation__ + * + * To run a mutation, you first call `useCancelAdjustmentTaskMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useCancelAdjustmentTaskMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [cancelAdjustmentTaskMutation, { data, loading, error }] = useCancelAdjustmentTaskMutation({ + * variables: { + * taskId: // value for 'taskId' + * }, + * }); + */ +export function useCancelAdjustmentTaskMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(CancelAdjustmentTaskDocument, options); + } +export type CancelAdjustmentTaskMutationHookResult = ReturnType; +export type CancelAdjustmentTaskMutationResult = Apollo.MutationResult; +export type CancelAdjustmentTaskMutationOptions = Apollo.BaseMutationOptions; +export const RerunAdjustmentTaskDocument = gql` + mutation RerunAdjustmentTask($responseId: Int!) { + rerunAdjustmentTask(responseId: $responseId) +} + `; +export type RerunAdjustmentTaskMutationFn = Apollo.MutationFunction; + +/** + * __useRerunAdjustmentTaskMutation__ + * + * To run a mutation, you first call `useRerunAdjustmentTaskMutation` within a React component and pass it any options that fit your needs. + * When your component renders, `useRerunAdjustmentTaskMutation` returns a tuple that includes: + * - A mutate function that you can call at any time to execute the mutation + * - An object with fields that represent the current status of the mutation's execution + * + * @param baseOptions options that will be passed into the mutation, supported options are listed on: https://www.apollographql.com/docs/react/api/react-hooks/#options-2; + * + * @example + * const [rerunAdjustmentTaskMutation, { data, loading, error }] = useRerunAdjustmentTaskMutation({ + * variables: { + * responseId: // value for 'responseId' + * }, + * }); + */ +export function useRerunAdjustmentTaskMutation(baseOptions?: Apollo.MutationHookOptions) { + const options = {...defaultOptions, ...baseOptions} + return Apollo.useMutation(RerunAdjustmentTaskDocument, options); + } +export type RerunAdjustmentTaskMutationHookResult = ReturnType; +export type RerunAdjustmentTaskMutationResult = Apollo.MutationResult; +export type RerunAdjustmentTaskMutationOptions = Apollo.BaseMutationOptions; \ No newline at end of file diff --git a/wren-ui/src/apollo/client/graphql/home.ts b/wren-ui/src/apollo/client/graphql/home.ts index 62cfce97ec..ced17aac35 100644 --- a/wren-ui/src/apollo/client/graphql/home.ts +++ b/wren-ui/src/apollo/client/graphql/home.ts @@ -113,12 +113,26 @@ const COMMON_RESPONSE = gql` askingTask { ...CommonAskingTask } + adjustment { + type + payload + } + adjustmentTask { + queryId + status + error { + ...CommonError + } + sql + traceId + } } ${COMMON_BREAKDOWN_DETAIL} ${COMMON_ANSWER_DETAIL} ${COMMON_CHART_DETAIL} ${COMMON_ASKING_TASK} + ${COMMON_ERROR} `; const COMMON_RECOMMENDED_QUESTIONS_TASK = gql` @@ -253,6 +267,19 @@ export const UPDATE_THREAD_RESPONSE = gql` ${COMMON_RESPONSE} `; +// For adjust reasoning steps or SQL +export const ADJUST_THREAD_RESPONSE = gql` + mutation AdjustThreadResponse( + $responseId: Int! + $data: AdjustThreadResponseInput! + ) { + adjustThreadResponse(responseId: $responseId, data: $data) { + ...CommonResponse + } + } + ${COMMON_RESPONSE} +`; + export const DELETE_THREAD = gql` mutation DeleteThread($where: ThreadUniqueWhereInput!) { deleteThread(where: $where) @@ -329,16 +356,6 @@ export const GENERATE_THREAD_RECOMMENDATION_QUESTIONS = gql` } `; -export const GENERATE_THREAD_RESPONSE_BREAKDOWN = gql` - mutation GenerateThreadResponseBreakdown($responseId: Int!) { - generateThreadResponseBreakdown(responseId: $responseId) { - ...CommonResponse - } - } - - ${COMMON_RESPONSE} -`; - export const GENERATE_THREAD_RESPONSE_ANSWER = gql` mutation GenerateThreadResponseAnswer($responseId: Int!) { generateThreadResponseAnswer(responseId: $responseId) { @@ -369,3 +386,32 @@ export const ADJUST_THREAD_RESPONSE_CHART = gql` } ${COMMON_RESPONSE} `; + +export const ADJUSTMENT_TASK = gql` + query AdjustmentTask($taskId: String!) { + adjustmentTask(taskId: $taskId) { + queryId + status + error { + code + shortMessage + message + stacktrace + } + sql + traceId + } + } +`; + +export const CANCEL_ADJUSTMENT_TASK = gql` + mutation CancelAdjustmentTask($taskId: String!) { + cancelAdjustmentTask(taskId: $taskId) + } +`; + +export const RERUN_ADJUSTMENT_TASK = gql` + mutation RerunAdjustmentTask($responseId: Int!) { + rerunAdjustmentTask(responseId: $responseId) + } +`; diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index 453ace9678..e0caa6503b 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -30,6 +30,9 @@ import { GenerateInstructionInput, InstructionStatus, InstructionResult, + AskFeedbackInput, + AskFeedbackResult, + AskFeedbackStatus, } from '@server/models/adaptor'; import { getLogger } from '@server/utils'; import * as Errors from '@server/utils/error'; @@ -119,6 +122,13 @@ export interface IWrenAIAdaptor { ): Promise; getInstructionResult(queryId: string): Promise; deleteInstructions(ids: number[], projectId: number): Promise; + + /** + * Ask feedback APIs + */ + createAskFeedback(input: AskFeedbackInput): Promise; + getAskFeedbackResult(queryId: string): Promise; + cancelAskFeedback(queryId: string): Promise; } export class WrenAIAdaptor implements IWrenAIAdaptor { @@ -647,6 +657,77 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } } + public async createAskFeedback( + input: AskFeedbackInput, + ): Promise { + try { + const body = { + tables: input.tables, + sql_generation_reasoning: input.sqlGenerationReasoning, + sql: input.sql, + project_id: input.projectId.toString(), + configurations: input.configurations, + }; + + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks`, + body, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug( + `Got error when creating ask feedback: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async getAskFeedbackResult( + queryId: string, + ): Promise { + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`, + ); + return this.transformAskFeedbackResult(res.data); + } catch (err: any) { + logger.debug( + `Got error when getting ask feedback result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async cancelAskFeedback(queryId: string): Promise { + try { + await axios.patch( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`, + { + status: 'stopped', + }, + ); + } catch (err: any) { + logger.debug( + `Got error when canceling ask feedback: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + private transformAskFeedbackResult(body: any): AskFeedbackResult { + const { status, error } = this.transformStatusAndError(body); + return { + status: status as AskFeedbackStatus, + error, + response: + body.response?.map((result: any) => ({ + sql: result.sql, + type: result.type?.toUpperCase() as AskCandidateType, + })) || [], + traceId: body.trace_id, + }; + } + private transformChartAdjustmentInput(input: ChartAdjustmentInput) { const { query, sql, adjustmentOption, chartSchema, configurations } = input; return { @@ -794,7 +875,8 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { | ChartStatus | SqlPairStatus | QuestionsStatus - | InstructionStatus; + | InstructionStatus + | AskFeedbackStatus; error?: { code: Errors.GeneralErrorCodes; message: string; diff --git a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts index c946c9b7a0..f849cffb88 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts @@ -200,7 +200,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { return res.data as EngineQueryResponse; } catch (err: any) { logger.debug(`Got error when querying duckdb: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } @@ -219,7 +222,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { await axios.patch(url.href, configPayload, { headers }); } catch (err: any) { logger.debug(`Got error when patching config: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } @@ -248,7 +254,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { return res.data; } catch (err: any) { logger.debug(`Got error when previewing data: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } diff --git a/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts b/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts new file mode 100644 index 0000000000..6db81fa906 --- /dev/null +++ b/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts @@ -0,0 +1,502 @@ +import { getLogger } from '@server/utils'; +import { + AskFeedbackInput, + AskFeedbackResult, + AskFeedbackStatus, +} from '@server/models/adaptor'; +import { + AskingTask, + IAskingTaskRepository, + IThreadResponseRepository, + ThreadResponse, + ThreadResponseAdjustmentType, +} from '@server/repositories'; +import { IWrenAIAdaptor } from '../adaptors'; +import { TelemetryEvent, WrenService } from '../telemetry/telemetry'; +import { PostHogTelemetry } from '../telemetry/telemetry'; + +const logger = getLogger('AdjustmentTaskTracker'); +logger.level = 'debug'; + +interface TrackedTask { + queryId: string; + taskId?: number; + lastPolled: number; + result?: AskFeedbackResult; + isFinalized: boolean; + threadResponseId: number; + question: string; + originalThreadResponseId: number; + rerun?: boolean; + adjustmentPayload?: { + originalThreadResponseId: number; + retrievedTables: string[]; + sqlGenerationReasoning: string; + }; +} + +export type TrackedAdjustmentResult = AskFeedbackResult & { + taskId?: number; + queryId: string; +}; + +export type CreateAdjustmentTaskInput = AskFeedbackInput & { + threadId: number; + question: string; + originalThreadResponseId: number; + configurations: { language: string }; +}; + +export type RerunAdjustmentTaskInput = { + threadResponseId: number; + threadId: number; + projectId: number; + configurations: { language: string }; +}; + +export interface IAdjustmentBackgroundTaskTracker { + createAdjustmentTask( + input: CreateAdjustmentTaskInput, + ): Promise<{ queryId: string }>; + getAdjustmentResult(queryId: string): Promise; + getAdjustmentResultById(id: number): Promise; + cancelAdjustmentTask(queryId: string): Promise; + rerunAdjustmentTask( + input: RerunAdjustmentTaskInput, + ): Promise<{ queryId: string }>; +} + +export class AdjustmentBackgroundTaskTracker + implements IAdjustmentBackgroundTaskTracker +{ + private wrenAIAdaptor: IWrenAIAdaptor; + private askingTaskRepository: IAskingTaskRepository; + private trackedTasks: Map = new Map(); + private trackedTasksById: Map = new Map(); + private pollingInterval: number; + private memoryRetentionTime: number; + private pollingIntervalId: NodeJS.Timeout; + private runningJobs = new Set(); + private threadResponseRepository: IThreadResponseRepository; + private telemetry: PostHogTelemetry; + + constructor({ + telemetry, + wrenAIAdaptor, + askingTaskRepository, + threadResponseRepository, + pollingInterval = 1000, // 1 second + memoryRetentionTime = 5 * 60 * 1000, // 5 minutes + }: { + telemetry: PostHogTelemetry; + wrenAIAdaptor: IWrenAIAdaptor; + askingTaskRepository: IAskingTaskRepository; + threadResponseRepository: IThreadResponseRepository; + pollingInterval?: number; + memoryRetentionTime?: number; + }) { + this.telemetry = telemetry; + this.wrenAIAdaptor = wrenAIAdaptor; + this.askingTaskRepository = askingTaskRepository; + this.threadResponseRepository = threadResponseRepository; + this.pollingInterval = pollingInterval; + this.memoryRetentionTime = memoryRetentionTime; + this.startPolling(); + } + + public async createAdjustmentTask( + input: CreateAdjustmentTaskInput, + ): Promise<{ queryId: string; createdThreadResponse: ThreadResponse }> { + try { + // Call the AI service to create a task + const response = await this.wrenAIAdaptor.createAskFeedback(input); + const queryId = response.queryId; + + // create a new asking task + const createdAskingTask = await this.askingTaskRepository.createOne({ + queryId, + question: input.question, + threadId: input.threadId, + detail: { + adjustment: true, + status: AskFeedbackStatus.UNDERSTANDING, + response: [], + error: null, + }, + }); + + // create a new thread response with adjustment payload + const createdThreadResponse = + await this.threadResponseRepository.createOne({ + question: input.question, + threadId: input.threadId, + askingTaskId: createdAskingTask.id, + adjustment: { + type: ThreadResponseAdjustmentType.REASONING, + payload: { + originalThreadResponseId: input.originalThreadResponseId, + retrievedTables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + }, + }, + }); + + // bind the thread response to the asking task + // todo: it's weird that we need to update the asking task again + // find a better way to do this + await this.askingTaskRepository.updateOne(createdAskingTask.id, { + threadResponseId: createdThreadResponse.id, + }); + + // Start tracking this task + const task = { + queryId, + lastPolled: Date.now(), + isFinalized: false, + originalThreadResponseId: input.originalThreadResponseId, + threadResponseId: createdThreadResponse.id, + question: input.question, + adjustmentPayload: { + originalThreadResponseId: input.originalThreadResponseId, + retrievedTables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + }, + } as TrackedTask; + this.trackedTasks.set(queryId, task); + this.trackedTasksById.set(createdThreadResponse.id, task); + + logger.info(`Created adjustment task with queryId: ${queryId}`); + return { queryId, createdThreadResponse }; + } catch (err) { + logger.error(`Failed to create adjustment task: ${err}`); + throw err; + } + } + + public async rerunAdjustmentTask( + input: RerunAdjustmentTaskInput, + ): Promise<{ queryId: string }> { + const currentThreadResponse = await this.threadResponseRepository.findOneBy( + { + id: input.threadResponseId, + }, + ); + if (!currentThreadResponse) { + throw new Error(`Thread response ${input.threadResponseId} not found`); + } + + const adjustment = currentThreadResponse.adjustment; + if (!adjustment) { + throw new Error( + `Thread response ${input.threadResponseId} has no adjustment`, + ); + } + + const originalThreadResponse = + await this.threadResponseRepository.findOneBy({ + id: adjustment.payload?.originalThreadResponseId, + }); + if (!originalThreadResponse) { + throw new Error( + `Original thread response ${adjustment.payload?.originalThreadResponseId} not found`, + ); + } + + // call createAskFeedback on AI service + const response = await this.wrenAIAdaptor.createAskFeedback({ + ...input, + tables: adjustment.payload?.retrievedTables, + sqlGenerationReasoning: adjustment.payload?.sqlGenerationReasoning, + sql: originalThreadResponse.sql, + }); + const queryId = response.queryId; + + // update asking task with new queryId + await this.askingTaskRepository.updateOne( + currentThreadResponse.askingTaskId, + { + queryId, + + // reset detail + detail: { + adjustment: true, + status: AskFeedbackStatus.UNDERSTANDING, + response: [], + error: null, + }, + }, + ); + + // schedule task + const task = { + queryId, + lastPolled: Date.now(), + isFinalized: false, + originalThreadResponseId: originalThreadResponse.id, + threadResponseId: currentThreadResponse.id, + question: originalThreadResponse.question, + rerun: true, + adjustmentPayload: { + originalThreadResponseId: originalThreadResponse.id, + retrievedTables: adjustment.payload?.retrievedTables, + sqlGenerationReasoning: adjustment.payload?.sqlGenerationReasoning, + }, + } as TrackedTask; + this.trackedTasks.set(queryId, task); + this.trackedTasksById.set(currentThreadResponse.id, task); + + logger.info(`Rerun adjustment task with queryId: ${queryId}`); + return { queryId }; + } + + public async getAdjustmentResult( + queryId: string, + ): Promise { + // Check if we're tracking this task in memory + const trackedTask = this.trackedTasks.get(queryId); + + if (trackedTask && trackedTask.result) { + return { + ...trackedTask.result, + queryId, + taskId: trackedTask.taskId, + }; + } + + // If not in memory or no result yet, check the database + return this.getAdjustmentResultFromDB({ queryId }); + } + + public async getAdjustmentResultById( + id: number, + ): Promise { + const task = this.trackedTasksById.get(id); + if (task) { + return this.getAdjustmentResult(task.queryId); + } + + return this.getAdjustmentResultFromDB({ taskId: id }); + } + + public async cancelAdjustmentTask(queryId: string): Promise { + await this.wrenAIAdaptor.cancelAskFeedback(queryId); + + // telemetry + const eventName = TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_CANCEL; + this.telemetry.sendEvent(eventName, { + queryId, + }); + } + + public stopPolling(): void { + if (this.pollingIntervalId) { + clearInterval(this.pollingIntervalId); + } + } + + private startPolling(): void { + this.pollingIntervalId = setInterval(() => { + this.pollTasks(); + }, this.pollingInterval); + } + + private async pollTasks(): Promise { + const now = Date.now(); + const tasksToRemove: string[] = []; + + // Create an array of job functions + const jobs = Array.from(this.trackedTasks.entries()).map( + ([queryId, task]) => + async () => { + try { + // Skip if the job is already running + if (this.runningJobs.has(queryId)) { + return; + } + + // Skip finalized tasks that have been in memory too long + if ( + task.isFinalized && + now - task.lastPolled > this.memoryRetentionTime + ) { + tasksToRemove.push(queryId); + return; + } + + // Skip finalized tasks + if (task.isFinalized) { + return; + } + + // Mark the job as running + this.runningJobs.add(queryId); + + // Poll for updates + logger.info(`Polling for updates for task ${queryId}`); + const result = + await this.wrenAIAdaptor.getAskFeedbackResult(queryId); + task.lastPolled = now; + + // if result is not changed, we don't need to update the database + if (!this.isResultChanged(task.result, result)) { + this.runningJobs.delete(queryId); + return; + } + + // update task in memory if any change + task.result = result; + + // update the database + logger.info(`Updating task ${queryId} in database`); + await this.updateTaskInDatabase({ queryId }, task); + + // Check if task is now finalized + if (this.isTaskFinalized(result.status)) { + task.isFinalized = true; + // update thread response if threadResponseId is provided + if (task.threadResponseId) { + await this.updateThreadResponseWhenTaskFinalized(task); + } + + // telemetry + const eventName = task.rerun + ? TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_RERUN + : TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE; + const eventProperties = { + taskId: task.taskId, + queryId: task.queryId, + status: result.status, + error: result.error, + adjustmentPayload: task.adjustmentPayload, + }; + if (result.status === AskFeedbackStatus.FINISHED) { + this.telemetry.sendEvent(eventName, eventProperties); + } else { + this.telemetry.sendEvent( + eventName, + eventProperties, + WrenService.AI, + false, + ); + } + + logger.info( + `Task ${queryId} is finalized with status: ${result.status}`, + ); + } + + // Mark the job as finished + this.runningJobs.delete(queryId); + } catch (err) { + this.runningJobs.delete(queryId); + logger.error(err.stack); + throw err; + } + }, + ); + + // Run all jobs in parallel + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // Log any rejected promises + results.forEach((result, index) => { + if (result.status === 'rejected') { + logger.error(`Job ${index} failed: ${result.reason}`); + } + }); + + // Clean up tasks that have been in memory too long + if (tasksToRemove.length > 0) { + logger.info( + `Cleaning up tasks that have been in memory too long. Tasks: ${tasksToRemove.join( + ', ', + )}`, + ); + } + for (const queryId of tasksToRemove) { + this.trackedTasks.delete(queryId); + } + }); + } + + private async updateThreadResponseWhenTaskFinalized( + task: TrackedTask, + ): Promise { + const response = task?.result?.response?.[0]; + if (!response) { + return; + } + await this.threadResponseRepository.updateOne(task.threadResponseId, { + sql: response?.sql, + }); + } + + private async getAdjustmentResultFromDB({ + queryId, + taskId, + }: { + queryId?: string; + taskId?: number; + }): Promise { + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + return null; + } + + return { + ...(taskRecord?.detail as AskFeedbackResult), + queryId: queryId || taskRecord?.queryId, + taskId: taskRecord?.id, + }; + } + + private async updateTaskInDatabase( + filter: { queryId?: string; taskId?: number }, + trackedTask: TrackedTask, + ): Promise { + const { queryId, taskId } = filter; + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + throw new Error('Asking task not found'); + } + + // update the task + await this.askingTaskRepository.updateOne(taskRecord.id, { + detail: { + adjustment: true, + ...trackedTask.result, + }, + }); + } + + private isTaskFinalized(status: AskFeedbackStatus): boolean { + return [ + AskFeedbackStatus.FINISHED, + AskFeedbackStatus.FAILED, + AskFeedbackStatus.STOPPED, + ].includes(status); + } + + private isResultChanged( + previousResult: AskFeedbackResult, + newResult: AskFeedbackResult, + ): boolean { + // check status change + if (previousResult?.status !== newResult.status) { + return true; + } + + return false; + } +} diff --git a/wren-ui/src/apollo/server/backgrounds/index.ts b/wren-ui/src/apollo/server/backgrounds/index.ts index 3b3438be62..d6e6ef8131 100644 --- a/wren-ui/src/apollo/server/backgrounds/index.ts +++ b/wren-ui/src/apollo/server/backgrounds/index.ts @@ -1,2 +1,3 @@ export * from './recommend-question'; export * from './chart'; +export * from './adjustmentBackgroundTracker'; diff --git a/wren-ui/src/apollo/server/models/adaptor.ts b/wren-ui/src/apollo/server/models/adaptor.ts index b75efd649f..0af534fb53 100644 --- a/wren-ui/src/apollo/server/models/adaptor.ts +++ b/wren-ui/src/apollo/server/models/adaptor.ts @@ -291,3 +291,31 @@ export interface InstructionResult { status: InstructionStatus; error?: WrenAIError; } + +// ask feedback +export interface AskFeedbackInput { + tables: string[]; + sqlGenerationReasoning: string; + sql: string; + projectId: number; + configurations?: ProjectConfigurations; +} + +export enum AskFeedbackStatus { + UNDERSTANDING = 'UNDERSTANDING', + GENERATING = 'GENERATING', + CORRECTING = 'CORRECTING', + FINISHED = 'FINISHED', + FAILED = 'FAILED', + STOPPED = 'STOPPED', +} + +export interface AskFeedbackResult { + status: AskFeedbackStatus; + error?: WrenAIError; + response: Array<{ + type: AskCandidateType.LLM; + sql: string; + }>; + traceId?: string; +} diff --git a/wren-ui/src/apollo/server/models/model.ts b/wren-ui/src/apollo/server/models/model.ts index afb1029f79..fd763d3de6 100644 --- a/wren-ui/src/apollo/server/models/model.ts +++ b/wren-ui/src/apollo/server/models/model.ts @@ -91,7 +91,7 @@ export interface CheckCalculatedFieldCanQueryData { export interface PreviewSQLData { sql: string; - projectId?: number; + projectId?: string; limit?: number; dryRun?: boolean; } diff --git a/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts index 889bbcf5d0..942f445ee1 100644 --- a/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts +++ b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts @@ -7,13 +7,19 @@ import { mapValues, snakeCase, } from 'lodash'; -import { AskResult } from '../models/adaptor'; +import { AskFeedbackResult, AskResult } from '../models/adaptor'; + +export type AskingTaskDetail = + | AskResult + | (AskFeedbackResult & { + adjustment?: boolean; + }); export interface AskingTask { id: number; queryId: string; question?: string; - detail?: AskResult; + detail?: AskingTaskDetail; threadId?: number; threadResponseId?: number; createdAt: Date; diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts index 87634989ac..f51d1eda91 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts @@ -39,6 +39,29 @@ export interface ThreadResponseChartDetail { adjustment?: boolean; } +export enum ThreadResponseAdjustmentType { + REASONING = 'REASONING', + APPLY_SQL = 'APPLY_SQL', +} + +export type ThreadResponseAdjustmentReasoningPayload = { + originalThreadResponseId?: number; + retrievedTables?: string[]; + sqlGenerationReasoning?: string; +}; + +export type ThreadResponseAdjustmentApplySqlPayload = { + originalThreadResponseId?: number; + sql?: string; +}; + +export interface ThreadResponseAdjustment { + type: ThreadResponseAdjustmentType; + // todo: I think we could use a better way to do this instead of using a union type + payload: ThreadResponseAdjustmentReasoningPayload & + ThreadResponseAdjustmentApplySqlPayload; +} + export interface ThreadResponse { id: number; // ID askingTaskId?: number; // Reference to asking_task.id @@ -49,6 +72,7 @@ export interface ThreadResponse { answerDetail?: ThreadResponseAnswerDetail; // AI generated text-based answer detail breakdownDetail?: ThreadResponseBreakdownDetail; // Thread response breakdown detail chartDetail?: ThreadResponseChartDetail; // Thread response chart detail + adjustment?: ThreadResponseAdjustment; // Thread response adjustment } export interface IThreadResponseRepository @@ -67,6 +91,7 @@ export class ThreadResponseRepository 'answerDetail', 'breakdownDetail', 'chartDetail', + 'adjustment', ]; constructor(knexPg: Knex) { @@ -102,11 +127,16 @@ export class ThreadResponseRepository res.chartDetail && typeof res.chartDetail === 'string' ? JSON.parse(res.chartDetail) : res.chartDetail; + const adjustment = + res.adjustment && typeof res.adjustment === 'string' + ? JSON.parse(res.adjustment) + : res.adjustment; return { ...res, answerDetail: answerDetail || null, breakdownDetail: breakdownDetail || null, chartDetail: chartDetail || null, + adjustment: adjustment || null, }; }) as ThreadResponse[]; } @@ -120,6 +150,7 @@ export class ThreadResponseRepository answerDetail: ThreadResponseAnswerDetail; breakdownDetail: ThreadResponseBreakdownDetail; chartDetail: ThreadResponseChartDetail; + adjustment: ThreadResponseAdjustment; }>, queryOptions?: IQueryOptions, ) { @@ -136,6 +167,7 @@ export class ThreadResponseRepository chartDetail: data.chartDetail ? JSON.stringify(data.chartDetail) : undefined, + adjustment: data.adjustment ? JSON.stringify(data.adjustment) : undefined, }; const executer = queryOptions?.tx ? queryOptions.tx : this.knex; const [result] = await executer(this.tableName) diff --git a/wren-ui/src/apollo/server/resolvers.ts b/wren-ui/src/apollo/server/resolvers.ts index 2e91bf22ac..fc01d77279 100644 --- a/wren-ui/src/apollo/server/resolvers.ts +++ b/wren-ui/src/apollo/server/resolvers.ts @@ -34,6 +34,9 @@ const resolvers = { suggestedQuestions: askingResolver.getSuggestedQuestions, instantRecommendedQuestions: askingResolver.getInstantRecommendedQuestions, + // Adjustment + adjustmentTask: askingResolver.getAdjustmentTask, + // Thread thread: askingResolver.getThread, threads: askingResolver.listThreads, @@ -97,6 +100,11 @@ const resolvers = { askingResolver.createInstantRecommendedQuestions, rerunAskingTask: askingResolver.rerunAskingTask, + // Adjustment + adjustThreadResponse: askingResolver.adjustThreadResponse, + cancelAdjustmentTask: askingResolver.cancelAdjustThreadResponseAnswer, + rerunAdjustmentTask: askingResolver.rerunAdjustThreadResponseAnswer, + // Thread createThread: askingResolver.createThread, updateThread: askingResolver.updateThread, diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index ecd509f016..b1fa3f9a97 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -5,6 +5,7 @@ import { AskResultType, RecommendationQuestionStatus, ChartAdjustmentOption, + AskFeedbackStatus, } from '@server/models/adaptor'; import { Thread } from '../repositories/threadRepository'; import { @@ -39,6 +40,14 @@ export interface Task { id: string; } +export interface AdjustmentTask { + queryId: string; + status: AskFeedbackStatus; + error: WrenAIError | null; + sql: string; + traceId: string; +} + export interface AskingTask { type: AskResultType | null; status: AskResultStatus; @@ -108,6 +117,13 @@ export class AskingResolver { this.generateThreadResponseChart.bind(this); this.adjustThreadResponseChart = this.adjustThreadResponseChart.bind(this); this.transformAskingTask = this.transformAskingTask.bind(this); + + this.adjustThreadResponse = this.adjustThreadResponse.bind(this); + this.cancelAdjustThreadResponseAnswer = + this.cancelAdjustThreadResponseAnswer.bind(this); + this.rerunAdjustThreadResponseAnswer = + this.rerunAdjustThreadResponseAnswer.bind(this); + this.getAdjustmentTask = this.getAdjustmentTask.bind(this); } public async generateProjectRecommendationQuestions( @@ -304,6 +320,7 @@ export class AskingResolver { breakdownDetail: response.breakdownDetail, answerDetail: response.answerDetail, chartDetail: response.chartDetail, + adjustment: response.adjustment, }); return acc; @@ -448,6 +465,98 @@ export class AskingResolver { return task; } + public async adjustThreadResponse( + _root: any, + args: { + responseId: number; + data: { + tables?: string[]; + sqlGenerationReasoning?: string; + sql?: string; + }; + }, + ctx: IContext, + ): Promise { + const { responseId, data } = args; + const askingService = ctx.askingService; + const project = await ctx.projectService.getCurrentProject(); + + if (data.sql) { + const response = await askingService.adjustThreadResponseWithSQL( + responseId, + { + sql: data.sql, + }, + ); + ctx.telemetry.sendEvent( + TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_WITH_SQL, + { + sql: data.sql, + responseId, + }, + ); + return response; + } + + return askingService.adjustThreadResponseAnswer( + responseId, + { + projectId: project.id, + tables: data.tables, + sqlGenerationReasoning: data.sqlGenerationReasoning, + }, + { + language: WrenAILanguage[project.language] || WrenAILanguage.EN, + }, + ); + } + + public async cancelAdjustThreadResponseAnswer( + _root: any, + args: { taskId: string }, + ctx: IContext, + ): Promise { + const { taskId } = args; + const askingService = ctx.askingService; + await askingService.cancelAdjustThreadResponseAnswer(taskId); + return true; + } + + public async rerunAdjustThreadResponseAnswer( + _root: any, + args: { responseId: number }, + ctx: IContext, + ): Promise { + const { responseId } = args; + const askingService = ctx.askingService; + const project = await ctx.projectService.getCurrentProject(); + await askingService.rerunAdjustThreadResponseAnswer( + responseId, + project.id, + { + language: WrenAILanguage[project.language] || WrenAILanguage.EN, + }, + ); + return true; + } + + public async getAdjustmentTask( + _root: any, + args: { taskId: string }, + ctx: IContext, + ): Promise { + const { taskId } = args; + const askingService = ctx.askingService; + const adjustmentTask = await askingService.getAdjustmentTask(taskId); + return { + queryId: adjustmentTask?.queryId, + status: adjustmentTask?.status, + error: adjustmentTask?.error, + sql: adjustmentTask?.response?.[0]?.sql, + traceId: adjustmentTask?.traceId, + }; + } + public async generateThreadResponseBreakdown( _root: any, args: { responseId: number }, @@ -604,6 +713,9 @@ export class AskingResolver { return parent.sql ? format(parent.sql) : null; }, askingTask: async (parent: ThreadResponse, _args: any, ctx: IContext) => { + if (parent.adjustment) { + return null; + } const askingService = ctx.askingService; const askingTask = await askingService.getAskingTaskById( parent.askingTaskId, @@ -611,6 +723,27 @@ export class AskingResolver { if (!askingTask) return null; return this.transformAskingTask(askingTask, ctx); }, + adjustmentTask: async ( + parent: ThreadResponse, + _args: any, + ctx: IContext, + ): Promise => { + if (!parent.adjustment) { + return null; + } + const askingService = ctx.askingService; + const adjustmentTask = await askingService.getAdjustmentTaskById( + parent.askingTaskId, + ); + if (!adjustmentTask) return null; + return { + queryId: adjustmentTask?.queryId, + status: adjustmentTask?.status, + error: adjustmentTask?.error, + sql: adjustmentTask?.response?.[0]?.sql, + traceId: adjustmentTask?.traceId, + }; + }, }); public getDetailStepNestedResolver = () => ({ diff --git a/wren-ui/src/apollo/server/resolvers/modelResolver.ts b/wren-ui/src/apollo/server/resolvers/modelResolver.ts index 5711d70765..c80b511690 100644 --- a/wren-ui/src/apollo/server/resolvers/modelResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/modelResolver.ts @@ -932,7 +932,7 @@ export class ModelResolver { ) { const { sql, projectId, limit, dryRun } = args.data; const project = projectId - ? await ctx.projectService.getProjectById(projectId) + ? await ctx.projectService.getProjectById(parseInt(projectId)) : await ctx.projectService.getCurrentProject(); const { manifest } = await ctx.deployService.getLastDeployment(project.id); return await ctx.queryService.preview(sql, { diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 5e9012723d..2a7835d4f2 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -656,6 +656,12 @@ export const typeDefs = gql` theta: String } + input AdjustThreadResponseInput { + tables: [String!] + sqlGenerationReasoning: String + sql: String + } + input PreviewDataInput { responseId: Int! # Optional, only used for preview data of a single step @@ -707,6 +713,24 @@ export const typeDefs = gql` adjustment: Boolean } + enum ThreadResponseAdjustmentType { + REASONING + APPLY_SQL + } + + type ThreadResponseAdjustment { + type: ThreadResponseAdjustmentType! + payload: JSON + } + + type AdjustmentTask { + queryId: String + status: AskingTaskStatus + error: Error + sql: String + traceId: String + } + type ThreadResponse { id: Int! threadId: Int! @@ -717,6 +741,8 @@ export const typeDefs = gql` answerDetail: ThreadResponseAnswerDetail chartDetail: ThreadResponseChartDetail askingTask: AskingTask + adjustment: ThreadResponseAdjustment + adjustmentTask: AdjustmentTask } # Thread only consists of basic information of a thread @@ -762,7 +788,7 @@ export const typeDefs = gql` input PreviewSQLDataInput { sql: String! - projectId: Int + projectId: String limit: Int dryRun: Boolean } @@ -955,6 +981,9 @@ export const typeDefs = gql` threadResponse(responseId: Int!): ThreadResponse! nativeSql(responseId: Int!): String! + # Adjustment + adjustmentTask(taskId: String!): AdjustmentTask + # Settings settings: Settings! @@ -1066,6 +1095,14 @@ export const typeDefs = gql` data: AdjustThreadResponseChartInput! ): ThreadResponse! + # Adjustment + adjustThreadResponse( + responseId: Int! + data: AdjustThreadResponseInput! + ): ThreadResponse! + cancelAdjustmentTask(taskId: String!): Boolean! + rerunAdjustmentTask(responseId: Int!): Boolean! + # Settings resetCurrentProject: Boolean! updateCurrentProject(data: UpdateCurrentProjectInput!): Boolean! diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 78eed90901..055caec9e0 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -16,6 +16,7 @@ import { IThreadRepository, Thread } from '../repositories/threadRepository'; import { IThreadResponseRepository, ThreadResponse, + ThreadResponseAdjustmentType, } from '../repositories/threadResponseRepository'; import { getLogger } from '@server/utils'; import { isEmpty, isNil } from 'lodash'; @@ -25,13 +26,19 @@ import { TelemetryEvent, WrenService, } from '../telemetry/telemetry'; -import { IViewRepository, Project } from '../repositories'; +import { + IAskingTaskRepository, + IViewRepository, + Project, +} from '../repositories'; import { IQueryService, PreviewDataResponse } from './queryService'; import { IMDLService } from './mdlService'; import { ThreadRecommendQuestionBackgroundTracker, ChartBackgroundTracker, ChartAdjustmentBackgroundTracker, + AdjustmentBackgroundTaskTracker, + TrackedAdjustmentResult, } from '../backgrounds'; import { getConfig } from '@server/config'; import { TextBasedAnswerBackgroundTracker } from '../backgrounds/textBasedAnswerBackgroundTracker'; @@ -94,6 +101,17 @@ export enum ThreadResponseAnswerStatus { INTERRUPTED = 'INTERRUPTED', } +// adjustment input +export interface AdjustmentReasoningInput { + tables: string[]; + sqlGenerationReasoning: string; + projectId: number; +} + +export interface AdjustmentSqlInput { + sql: string; +} + export interface IAskingService { /** * Asking task. @@ -155,6 +173,23 @@ export interface IAskingService { input: ChartAdjustmentOption, configurations: { language: string }, ): Promise; + adjustThreadResponseWithSQL( + threadResponseId: number, + input: AdjustmentSqlInput, + ): Promise; + adjustThreadResponseAnswer( + threadResponseId: number, + input: AdjustmentReasoningInput, + configurations: { language: string }, + ): Promise; + cancelAdjustThreadResponseAnswer(taskId: string): Promise; + rerunAdjustThreadResponseAnswer( + threadResponseId: number, + projectId: number, + configurations: { language: string }, + ): Promise<{ queryId: string }>; + getAdjustmentTask(taskId: string): Promise; + getAdjustmentTaskById(id: number): Promise; changeThreadResponseAnswerDetailStatus( responseId: number, status: ThreadResponseAnswerStatus, @@ -379,6 +414,8 @@ export class AskingService implements IAskingService { private telemetry: PostHogTelemetry; private mdlService: IMDLService; private askingTaskTracker: IAskingTaskTracker; + private askingTaskRepository: IAskingTaskRepository; + private adjustmentBackgroundTracker: AdjustmentBackgroundTaskTracker; constructor({ telemetry, @@ -388,6 +425,7 @@ export class AskingService implements IAskingService { viewRepository, threadRepository, threadResponseRepository, + askingTaskRepository, queryService, mdlService, askingTaskTracker, @@ -399,6 +437,7 @@ export class AskingService implements IAskingService { viewRepository: IViewRepository; threadRepository: IThreadRepository; threadResponseRepository: IThreadResponseRepository; + askingTaskRepository: IAskingTaskRepository; queryService: IQueryService; mdlService: IMDLService; askingTaskTracker: IAskingTaskTracker; @@ -441,7 +480,14 @@ export class AskingService implements IAskingService { wrenAIAdaptor, threadRepository, }); + this.adjustmentBackgroundTracker = new AdjustmentBackgroundTaskTracker({ + telemetry, + wrenAIAdaptor, + askingTaskRepository, + threadResponseRepository, + }); + this.askingTaskRepository = askingTaskRepository; this.mdlService = mdlService; this.askingTaskTracker = askingTaskTracker; } @@ -1013,6 +1059,97 @@ export class AskingService implements IAskingService { return lastDeploy.hash; } + public async adjustThreadResponseWithSQL( + threadResponseId: number, + input: AdjustmentSqlInput, + ): Promise { + const response = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!response) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + return await this.threadResponseRepository.createOne({ + sql: input.sql, + threadId: response.threadId, + question: response.question, + adjustment: { + type: ThreadResponseAdjustmentType.APPLY_SQL, + payload: { + originalThreadResponseId: response.id, + sql: input.sql, + }, + }, + }); + } + + public async adjustThreadResponseAnswer( + threadResponseId: number, + input: AdjustmentReasoningInput, + configurations: { language: string }, + ): Promise { + const originalThreadResponse = + await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!originalThreadResponse) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + const { createdThreadResponse } = + await this.adjustmentBackgroundTracker.createAdjustmentTask({ + threadId: originalThreadResponse.threadId, + tables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + sql: originalThreadResponse.sql, + projectId: input.projectId, + configurations, + question: originalThreadResponse.question, + originalThreadResponseId: originalThreadResponse.id, + }); + return createdThreadResponse; + } + + public async cancelAdjustThreadResponseAnswer(taskId: string): Promise { + // call cancelAskFeedback on AI service + await this.adjustmentBackgroundTracker.cancelAdjustmentTask(taskId); + } + + public async rerunAdjustThreadResponseAnswer( + threadResponseId: number, + projectId: number, + configurations: { language: string }, + ): Promise<{ queryId: string }> { + const threadResponse = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!threadResponse) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + const { queryId } = + await this.adjustmentBackgroundTracker.rerunAdjustmentTask({ + threadId: threadResponse.threadId, + threadResponseId, + projectId, + configurations, + }); + return { queryId }; + } + + public async getAdjustmentTask( + taskId: string, + ): Promise { + return this.adjustmentBackgroundTracker.getAdjustmentResult(taskId); + } + + public async getAdjustmentTaskById( + id: number, + ): Promise { + return this.adjustmentBackgroundTracker.getAdjustmentResultById(id); + } + /** * Get the thread response of a thread for asking * @param threadId diff --git a/wren-ui/src/apollo/server/services/askingTaskTracker.ts b/wren-ui/src/apollo/server/services/askingTaskTracker.ts index 2ed0274abe..a2f18a61f0 100644 --- a/wren-ui/src/apollo/server/services/askingTaskTracker.ts +++ b/wren-ui/src/apollo/server/services/askingTaskTracker.ts @@ -404,7 +404,7 @@ export class AskingTaskTracker implements IAskingTaskTracker { } return { - ...taskRecord?.detail, + ...(taskRecord?.detail as AskResult), queryId: queryId || taskRecord?.queryId, question: taskRecord?.question, taskId: taskRecord?.id, diff --git a/wren-ui/src/apollo/server/telemetry/telemetry.ts b/wren-ui/src/apollo/server/telemetry/telemetry.ts index 99c469d54a..930982013b 100644 --- a/wren-ui/src/apollo/server/telemetry/telemetry.ts +++ b/wren-ui/src/apollo/server/telemetry/telemetry.ts @@ -56,6 +56,12 @@ export enum TelemetryEvent { HOME_GENERATE_PROJECT_RECOMMENDATION_QUESTIONS = 'home_generate_project_recommendation_questions', HOME_GENERATE_THREAD_RECOMMENDATION_QUESTIONS = 'home_generate_thread_recommendation_questions', + // adjustment + HOME_ADJUST_THREAD_RESPONSE = 'home_adjust_thread_response', + HOME_ADJUST_THREAD_RESPONSE_CANCEL = 'home_adjust_thread_response_cancel', + HOME_ADJUST_THREAD_RESPONSE_RERUN = 'home_adjust_thread_response_rerun', + HOME_ADJUST_THREAD_RESPONSE_WITH_SQL = 'home_adjust_thread_response_with_sql', + // event after ask HOME_CREATE_VIEW = 'home_create_view', HOME_PREVIEW_ANSWER = 'home_preview_answer', diff --git a/wren-ui/src/apollo/server/utils/error.ts b/wren-ui/src/apollo/server/utils/error.ts index 58fc196401..87b8069033 100644 --- a/wren-ui/src/apollo/server/utils/error.ts +++ b/wren-ui/src/apollo/server/utils/error.ts @@ -41,6 +41,9 @@ export enum GeneralErrorCodes { GENERATE_QUESTIONS_ERROR = 'GENERATE_QUESTIONS_ERROR', INVALID_SQL_ERROR = 'INVALID_SQL_ERROR', + // wren engine error + WREN_ENGINE_ERROR = 'WREN_ENGINE_ERROR', + // asking task error // when rerun from cancelled, the task is identified as general or misleading query IDENTIED_AS_GENERAL = 'IDENTIED_AS_GENERAL', diff --git a/wren-ui/src/common.ts b/wren-ui/src/common.ts index aee4518976..8337b2b6c2 100644 --- a/wren-ui/src/common.ts +++ b/wren-ui/src/common.ts @@ -128,6 +128,7 @@ export const initComponents = () => { queryService, mdlService, askingTaskTracker, + askingTaskRepository, }); const dashboardService = new DashboardService({ projectService, diff --git a/wren-ui/src/components/diagram/CustomDropdown.tsx b/wren-ui/src/components/diagram/CustomDropdown.tsx index 646ab2e0c0..9b3fd0f68f 100644 --- a/wren-ui/src/components/diagram/CustomDropdown.tsx +++ b/wren-ui/src/components/diagram/CustomDropdown.tsx @@ -1,4 +1,5 @@ import React from 'react'; +import styled from 'styled-components'; import { Dropdown, Menu } from 'antd'; import { ItemType } from 'antd/lib/menu/hooks/useItems'; import { MORE_ACTION, NODE_TYPE } from '@/utils/enum'; @@ -6,6 +7,8 @@ import EditOutlined from '@ant-design/icons/EditOutlined'; import ReloadOutlined from '@ant-design/icons/ReloadOutlined'; import EyeInvisibleOutlined from '@ant-design/icons/EyeInvisibleOutlined'; import EyeOutlined from '@ant-design/icons/EyeOutlined'; +import CodeFilled from '@ant-design/icons/CodeFilled'; +import { EditSVG } from '@/utils/svgs'; import { DeleteCalculatedFieldModal, DeleteRelationshipModal, @@ -16,16 +19,23 @@ import { DeleteInstructionModal, } from '@/components/modals/DeleteModal'; +const StyledMenu = styled(Menu)` + .ant-dropdown-menu-item:not(.ant-dropdown-menu-item-disabled) { + color: var(--gray-8); + } +`; + interface Props { [key: string]: any; - onMoreClick: (type: MORE_ACTION) => void; + onMoreClick: (type: MORE_ACTION | { type: MORE_ACTION; data: any }) => void; onMenuEnter?: (event: React.MouseEvent) => void; children: React.ReactNode; + onDropdownVisibleChange?: (visible: boolean) => void; } const makeDropdown = (getItems: (props: Props) => ItemType[]) => (props: Props) => { - const { children, onMenuEnter } = props; + const { children, onMenuEnter, onDropdownVisibleChange } = props; const items = getItems(props); @@ -34,12 +44,13 @@ const makeDropdown = trigger={['click']} overlayStyle={{ minWidth: 100, userSelect: 'none' }} overlay={ - e.domEvent.stopPropagation()} items={items} onMouseEnter={onMenuEnter} /> } + onVisibleChange={onDropdownVisibleChange} > {children} @@ -53,7 +64,7 @@ export const ModelDropdown = makeDropdown((props: Props) => { { label: ( <> - + Update Columns ), @@ -102,7 +113,7 @@ export const ColumnDropdown = makeDropdown((props: Props) => { { label: ( <> - + Edit ), @@ -128,12 +139,12 @@ export const DashboardItemDropdown = makeDropdown((props: Props) => { { label: isHideLegend ? ( <> - + Show categories ) : ( <> - {} + {} Hide categories ), @@ -143,7 +154,7 @@ export const DashboardItemDropdown = makeDropdown((props: Props) => { { label: ( <> - + Refresh ), @@ -175,7 +186,7 @@ export const SQLPairDropdown = makeDropdown( { label: ( <> - + View ), @@ -189,7 +200,7 @@ export const SQLPairDropdown = makeDropdown( { label: ( <> - + Edit ), @@ -234,7 +245,7 @@ export const InstructionDropdown = makeDropdown( { label: ( <> - + View ), @@ -248,7 +259,7 @@ export const InstructionDropdown = makeDropdown( { label: ( <> - + Edit ), @@ -281,3 +292,37 @@ export const InstructionDropdown = makeDropdown( return items; }, ); + +export const AdjustAnswerDropdown = makeDropdown( + ( + props: Props & { + onMoreClick: (payload: { type: MORE_ACTION; data: any }) => void; + }, + ) => { + const { onMoreClick, data } = props; + const items: ItemType[] = [ + { + label: 'Adjust steps', + icon: , + disabled: !data.sqlGenerationReasoning, + key: 'adjust-steps', + onClick: () => + onMoreClick({ + type: MORE_ACTION.ADJUST_STEPS, + data, + }), + }, + { + label: 'Adjust SQL', + icon: , + key: 'adjust-sql', + onClick: () => + onMoreClick({ + type: MORE_ACTION.ADJUST_SQL, + data, + }), + }, + ]; + return items; + }, +); diff --git a/wren-ui/src/components/editor/CodeBlock.tsx b/wren-ui/src/components/editor/CodeBlock.tsx index 169fa0e2c5..251f4a1b38 100644 --- a/wren-ui/src/components/editor/CodeBlock.tsx +++ b/wren-ui/src/components/editor/CodeBlock.tsx @@ -1,8 +1,10 @@ import { useEffect } from 'react'; -import { Typography } from 'antd'; +import { Button, Typography } from 'antd'; import styled from 'styled-components'; import '@/components/editor/AceEditor'; import { Loading } from '@/components/PageLoading'; +import CheckOutlined from '@ant-design/icons/CheckOutlined'; +import CopyOutlined from '@ant-design/icons/CopyOutlined'; const Block = styled.div<{ inline?: boolean; maxHeight?: string }>` position: relative; @@ -41,11 +43,15 @@ const Block = styled.div<{ inline?: boolean; maxHeight?: string }>` const CopyText = styled(Typography.Text)` position: absolute; top: 8px; - right: 8px; + right: 20px; font-size: 0; .ant-typography-copy { font-size: 12px; } + + .ant-btn:not(:hover) { + color: var(--gray-8); + } `; interface Props { @@ -112,7 +118,27 @@ export default function CodeBlock(props: Props) {
{lines} - {copyable && {code}} + {copyable && ( + } + size="small" + style={{ backgroundColor: 'transparent' }} + />, +
diff --git a/wren-ui/src/components/editor/MarkdownBlock.tsx b/wren-ui/src/components/editor/MarkdownBlock.tsx index 301f631bec..b16752542c 100644 --- a/wren-ui/src/components/editor/MarkdownBlock.tsx +++ b/wren-ui/src/components/editor/MarkdownBlock.tsx @@ -51,6 +51,23 @@ const ReactMarkdownBlock = styled(ReactMarkdown)` border-collapse: collapse; margin-bottom: 16px; } + ol, + ul, + dl { + padding-inline-start: 20px; + } + h1 code, + h2 code, + h3 code, + h4 code, + li code, + p code { + font-size: 12px; + background: var(--gray-4); + color: var(--gray-8); + padding: 2px 4px; + border-radius: 4px; + } `; export default function MarkdownBlock(props: { content: string }) { diff --git a/wren-ui/src/components/editor/MarkdownEditor.tsx b/wren-ui/src/components/editor/MarkdownEditor.tsx new file mode 100644 index 0000000000..8e6ee4e0c4 --- /dev/null +++ b/wren-ui/src/components/editor/MarkdownEditor.tsx @@ -0,0 +1,254 @@ +import clsx from 'clsx'; +import { Button, Mentions, Typography } from 'antd'; +import styled from 'styled-components'; +import { useState, useContext, useRef } from 'react'; +import ReadOutlined from '@ant-design/icons/ReadOutlined'; +import EditOutlined from '@ant-design/icons/EditOutlined'; +import { nextTick } from '@/utils/time'; +import { Mention } from '@/hooks/useMentions'; +import { FormItemInputContext } from 'antd/lib/form/context'; +import MarkdownBlock from './MarkdownBlock'; + +const Wrapper = styled.div` + transition: all 0.3s cubic-bezier(0.645, 0.045, 0.355, 1); + + &:hover { + border-color: var(--geekblue-5) !important; + } + + &.adm-markdown-editor-error { + border-color: var(--red-5) !important; + + .adm-markdown-editor-length { + color: var(--red-5) !important; + } + } + &:not(.adm-markdown-editor-error).adm-markdown-editor-focused { + border-color: var(--geekblue-5) !important; + box-shadow: 0 0 0 2px rgba(47, 84, 235, 0.2); + } + + &.adm-markdown-editor-focused.adm-markdown-editor-error { + borer-color: var(--red-4) !important; + box-shadow: 0 0 0 2px rgba(255, 77, 79, 0.2); + } +`; + +const OverflowContainer = styled.div` + overflow-y: auto; + max-height: 318px; +`; + +const LinkButton = styled(Button)` + color: var(--gray-7); +`; + +const StyledTextArea = styled(Mentions)` + border: none; + border-radius: 0; + + textarea { + padding: 16px 16px 16px 20px; + } +`; + +interface Props { + value?: string; + onChange?: (value: string) => void; + maxLength?: number; + autoFocus?: boolean; + mentions?: Mention[]; +} + +const MENTION_PREFIX = '@'; + +const MentionOption = (props: Mention) => { + return ( + +
+
+ {props.icon} + + {props.label} + +
+ {props.meta && ( +
+ + ({props.meta}) + + {props.nodeType} +
+ )} +
+
+ ); +}; + +export default function MarkdownEditor(props: Props) { + const { value, onChange, maxLength, autoFocus, mentions } = props; + const $wrapper = useRef(null); + const $textarea = useRef( + null, + ); + const [focused, setFocused] = useState(false); + const [isPreviewMode, setIsPreviewMode] = useState(false); + + const formItemContext = useContext(FormItemInputContext); + const { status } = formItemContext; + + const change = (targetValue: string) => { + onChange?.(targetValue); + }; + + const select = (option: Mention) => { + const textarea = $textarea.current?.textarea; + if (!textarea) return; + + // go to the start of the mention + const mentionStart = ( + value?.slice(0, textarea.selectionStart) || '' + ).lastIndexOf(MENTION_PREFIX); + const start = mentionStart >= 0 ? mentionStart : textarea.selectionStart; + const end = textarea.selectionEnd; + const newValue = value?.slice(0, start) + option.value + value?.slice(end); + // update the value and move the cursor + onChange?.(newValue || ''); + nextTick().then(() => { + textarea.selectionStart = textarea.selectionEnd = + start + option.value.length; + }); + }; + + const keydown = (e: React.KeyboardEvent) => { + if (e.key === 'Tab') { + e.preventDefault(); + const textarea = e.currentTarget; + const start = textarea.selectionStart; + const end = textarea.selectionEnd; + // Set the value with a tab character or spaces + const tabCharacter = ' '; // Use '\t' for a tab character or spaces for spaces + const newValue = + value?.slice(0, start) + tabCharacter + value?.slice(end); + // update the value and move the cursor + onChange?.(newValue || ''); + nextTick().then(() => { + textarea.selectionStart = textarea.selectionEnd = + start + tabCharacter.length; + }); + } + if (e.key === '`') { + const textarea = e.currentTarget; + const start = textarea.selectionStart; + const end = textarea.selectionEnd; + + if (start !== end) { + e.preventDefault(); + const selection = `\`${value?.slice(start, end)}\``; + const newValue = value?.slice(0, start) + selection + value?.slice(end); + // update the value and move the cursor + onChange?.(newValue || ''); + nextTick().then(() => { + textarea.selectionStart = textarea.selectionEnd = + start + selection.length; + }); + } + } + if (e.key === 'ArrowDown' || e.key === 'ArrowUp') { + // check if the mention dropdown menu exist + const dropdownMenu = $wrapper.current?.querySelector( + '.ant-mentions-dropdown-menu', + ); + if (dropdownMenu) { + // delay to make sure the menu active item is rendered + nextTick().then(() => { + const activeItem = dropdownMenu.querySelector( + '.ant-mentions-dropdown-menu-item-active', + ) as HTMLLIElement; + if (activeItem) { + const menuRect = dropdownMenu.getBoundingClientRect(); + const activeRect = activeItem.getBoundingClientRect(); + // check if active item is outside viewport + if (activeRect.bottom > menuRect.bottom) { + // scroll down + dropdownMenu.scrollTo({ + top: + dropdownMenu.scrollTop + + (activeRect.bottom - menuRect.bottom), + behavior: 'smooth', + }); + } else if (activeRect.top < menuRect.top) { + // scroll up + dropdownMenu.scrollTo({ + top: dropdownMenu.scrollTop - (menuRect.top - activeRect.top), + behavior: 'smooth', + }); + } + } + }); + } + } + }; + + return ( + +
+
+ {maxLength ? ( + <> + {value?.length} / {maxLength} characters + + ) : ( + <>{value?.length} characters + )} +
+ : } + type="link" + size="small" + onClick={() => setIsPreviewMode(!isPreviewMode)} + > + {isPreviewMode ? 'Edit mode' : 'Read mode'} + +
+ + {isPreviewMode ? ( + + ) : ( + $wrapper?.current} + onChange={change} + onSelect={select} + onKeyDown={keydown} + onFocus={() => setFocused(true)} + onBlur={() => setFocused(false)} + value={value} + prefix={MENTION_PREFIX} + maxLength={maxLength} + > + {(mentions || []).map(MentionOption)} + + )} + +
+ ); +} diff --git a/wren-ui/src/components/modals/AdjustReasoningStepsModal.tsx b/wren-ui/src/components/modals/AdjustReasoningStepsModal.tsx new file mode 100644 index 0000000000..f69086e6b7 --- /dev/null +++ b/wren-ui/src/components/modals/AdjustReasoningStepsModal.tsx @@ -0,0 +1,190 @@ +import { useEffect, useMemo } from 'react'; +import { keyBy } from 'lodash'; +import styled from 'styled-components'; +import { Form, Modal, Select, Tag } from 'antd'; +import QuestionCircleOutlined from '@ant-design/icons/QuestionCircleOutlined'; +import { ERROR_TEXTS } from '@/utils/error'; +import useMentions from '@/hooks/useMentions'; +import { ModalAction } from '@/hooks/useModalAction'; +import MarkdownEditor from '@/components/editor/MarkdownEditor'; +import { useListModelsQuery } from '@/apollo/client/graphql/model.generated'; + +const MultiSelect = styled(Select)` + .ant-select-selector { + padding-top: 3px; + } + .ant-tag { + padding: 3px 5px; + margin-right: 3px; + margin-bottom: 3px; + } +`; + +const TagText = styled.div` + line-height: 16px; +`; + +type Props = ModalAction<{ + responseId: number; + retrievedTables: string[]; + sqlGenerationReasoning: string; +}> & { + loading?: boolean; +}; + +export default function AdjustReasoningStepsModal(props: Props) { + const { visible, defaultValue, loading, onSubmit, onClose } = props; + const [form] = Form.useForm(); + + const { mentions } = useMentions({ includeColumns: true, skip: !visible }); + const listModelsResult = useListModelsQuery({ skip: !visible }); + const modelNameMap = keyBy( + listModelsResult.data?.listModels, + 'referenceName', + ); + const modelOptions = useMemo(() => { + return listModelsResult.data?.listModels.map((model) => ({ + label: model.displayName, + value: model.referenceName, + })); + }, [listModelsResult.data?.listModels]); + + useEffect(() => { + if (!visible) return; + const listModels = listModelsResult.data?.listModels || []; + const retrievedTables = listModels.reduce((result, model) => { + if (defaultValue?.retrievedTables.includes(model.referenceName)) { + console.log(model.referenceName); + result.push({ label: model.displayName, value: model.referenceName }); + } + return result; + }, []); + form.setFieldsValue({ + tables: retrievedTables, + sqlGenerationReasoning: defaultValue?.sqlGenerationReasoning, + }); + }, [form, defaultValue, visible, listModelsResult.data?.listModels]); + + const tagRender = (props) => { + const { value, closable, onClose } = props; + const model = modelNameMap[value]; + return ( + e.stopPropagation()} + closable={closable} + onClose={onClose} + className="d-flex align-center bg-gray-3 border-gray-3" + style={{ maxWidth: 140 }} + > +
+ + {model.displayName} + + + {model.referenceName} + +
+
+ ); + }; + + const reset = () => { + form.resetFields(); + }; + + const submit = async () => { + form + .validateFields() + .then(async (values) => { + await onSubmit({ + responseId: defaultValue.responseId, + data: { + ...values, + tables: values.tables.map((table) => table.value), + }, + }); + onClose(); + }) + .catch(console.error); + }; + + return ( + +
+ + Select the tables needed to answer your question.{' '} + + Tables not selected won't be used in SQL generation. + + + } + > + + + + + Protip: Use @ to choose model in the textarea. + + } + > +
+ Edit the reasoning logic below. Each step should build toward + answering the question accurately. +
+ + + +
+
+
+ ); +} diff --git a/wren-ui/src/components/modals/AdjustSQLModal.tsx b/wren-ui/src/components/modals/AdjustSQLModal.tsx new file mode 100644 index 0000000000..905152ac28 --- /dev/null +++ b/wren-ui/src/components/modals/AdjustSQLModal.tsx @@ -0,0 +1,215 @@ +import { useEffect, useState } from 'react'; +import { Alert, Button, Form, Modal, Typography } from 'antd'; +import InfoCircleOutlined from '@ant-design/icons/InfoCircleOutlined'; +import { ERROR_TEXTS } from '@/utils/error'; +import { ModalAction } from '@/hooks/useModalAction'; +import SQLEditor from '@/components/editor/SQLEditor'; +import { parseGraphQLError } from '@/utils/errorHandler'; +import ErrorCollapse from '@/components/ErrorCollapse'; +import PreviewData from '@/components/dataPreview/PreviewData'; +import { usePreviewSqlMutation } from '@/apollo/client/graphql/sql.generated'; + +interface AdjustSQLFormValues { + responseId: number; + sql: string; +} + +type Props = ModalAction & { + loading?: boolean; +}; + +export default function AdjustSQLModal(props: Props) { + const { defaultValue, loading, onClose, onSubmit, visible } = props; + + const [form] = Form.useForm(); + const [error, setError] = + useState>(null); + const [previewing, setPreviewing] = useState(false); + const [submitting, setSubmitting] = useState(false); + const [showPreview, setShowPreview] = useState(false); + + const [previewSqlMutation, previewSqlResult] = usePreviewSqlMutation(); + + const sqlValue = Form.useWatch('sql', form); + + useEffect(() => { + if (visible) { + form.setFieldsValue({ + sql: defaultValue?.sql, + }); + } + }, [visible, defaultValue]); + + const handleReset = () => { + previewSqlResult.reset(); + setShowPreview(false); + setError(null); + form.resetFields(); + }; + + const onValidateSQL = async () => { + await previewSqlMutation({ + variables: { + data: { + sql: sqlValue, + limit: 1, + dryRun: true, + }, + }, + }); + }; + + const handleError = (error) => { + const graphQLError = parseGraphQLError(error); + setError({ ...graphQLError, shortMessage: 'Invalid SQL syntax' }); + console.error(graphQLError); + }; + + const onPreviewData = async () => { + setError(null); + setPreviewing(true); + try { + await onValidateSQL(); + setShowPreview(true); + await previewSqlMutation({ + variables: { + data: { + sql: sqlValue, + limit: 50, + }, + }, + }); + } catch (error) { + setShowPreview(false); + handleError(error); + } finally { + setPreviewing(false); + } + }; + + const onSubmitButton = () => { + setError(null); + setSubmitting(true); + setShowPreview(false); + form + .validateFields() + .then(async (values) => { + try { + await onValidateSQL(); + await onSubmit({ + responseId: defaultValue?.responseId, + sql: values.sql, + }); + onClose(); + } catch (error) { + handleError(error); + } finally { + setSubmitting(false); + } + }) + .catch((err) => { + setSubmitting(false); + console.error(err); + }); + }; + + const confirmLoading = loading || submitting; + const disabled = !sqlValue; + + return ( + handleReset()} + footer={ +
+
+ + + The SQL statement used here follows Wren SQL, which is + based on ANSI SQL and optimized for Wren AI.{` `} + + Learn more about the syntax. + + +
+
+ + +
+
+ } + > +
+ + + +
+
+ + Data preview (50 rows) + + + {showPreview && ( +
+ +
+ )} +
+ {!!error && ( + } + /> + )} +
+ ); +} diff --git a/wren-ui/src/components/pages/home/preparation/PreparationStatus.tsx b/wren-ui/src/components/pages/home/preparation/PreparationStatus.tsx index 26d4a65616..fcf1cb2700 100644 --- a/wren-ui/src/components/pages/home/preparation/PreparationStatus.tsx +++ b/wren-ui/src/components/pages/home/preparation/PreparationStatus.tsx @@ -5,24 +5,38 @@ import ReloadOutlined from '@ant-design/icons/ReloadOutlined'; import { attachLoading } from '@/utils/helper'; import { getIsFinished } from '@/hooks/useAskPrompt'; import { AskingTaskStatus } from '@/apollo/client/graphql/__types__'; -import type { Props } from './index'; +import type { PreparedTask, Props } from './index'; -export default function PreparationStatus(props: Props) { - const { data, onStopAskingTask, onReRunAskingTask } = props; +export default function PreparationStatus( + props: Props & { preparedTask: PreparedTask }, +) { + const { + data, + preparedTask, + onStopAskingTask, + onReRunAskingTask, + onStopAdjustTask, + onReRunAdjustTask, + } = props; const [stopLoading, setStopLoading] = useState(false); const [reRunLoading, setReRunLoading] = useState(false); - const { askingTask } = data; - const isProcessing = !getIsFinished(askingTask.status); + const isProcessing = !getIsFinished(preparedTask.status); const onCancel = (e) => { e.stopPropagation(); - const stopAskingTask = attachLoading(onStopAskingTask, setStopLoading); - stopAskingTask(askingTask.queryId); + const stopPreparedTask = preparedTask.isAdjustment + ? onStopAdjustTask + : onStopAskingTask; + const stopAskingTask = attachLoading(stopPreparedTask, setStopLoading); + stopAskingTask(preparedTask.queryId); }; const onReRun = (e) => { e.stopPropagation(); - const reRunAskingTask = attachLoading(onReRunAskingTask, setReRunLoading); + const reRunPreparedTask = preparedTask.isAdjustment + ? onReRunAdjustTask + : onReRunAskingTask; + const reRunAskingTask = attachLoading(reRunPreparedTask, setReRunLoading); reRunAskingTask(data); }; @@ -38,7 +52,7 @@ export default function PreparationStatus(props: Props) { Cancel ); - } else if (askingTask.status === AskingTaskStatus.STOPPED) { + } else if (preparedTask.status === AskingTaskStatus.STOPPED) { return ( Cancelled by user @@ -54,9 +68,9 @@ export default function PreparationStatus(props: Props) { ); - } else if (askingTask.status === AskingTaskStatus.FINISHED) { + } else if (preparedTask.status === AskingTaskStatus.FINISHED) { const showView = data.view !== null; - const showSqlPair = !!askingTask?.candidates[0]?.sqlPair; + const showSqlPair = !!preparedTask?.candidates[0]?.sqlPair; return (
{showView || showSqlPair ? '1 step' : '3 steps'} diff --git a/wren-ui/src/components/pages/home/preparation/PreparationSteps.tsx b/wren-ui/src/components/pages/home/preparation/PreparationSteps.tsx index ff3ee70a46..1caeb7866a 100644 --- a/wren-ui/src/components/pages/home/preparation/PreparationSteps.tsx +++ b/wren-ui/src/components/pages/home/preparation/PreparationSteps.tsx @@ -13,7 +13,7 @@ import { ProcessStateMachine, convertAskingTaskToProcessState, } from '@/hooks/useAskProcessState'; -import type { Props } from './index'; +import type { Props, PreparedTask } from './index'; const StyledBadge = styled(Badge)` position: absolute; @@ -47,29 +47,31 @@ const getProcessDot = (processing: boolean) => { ) : null; }; -export default function PreparationSteps(props: Props) { - const { className, data, askingStreamTask, minimized } = props; - const { askingTask, view, sql } = data; +export default function PreparationSteps( + props: Props & { preparedTask: PreparedTask }, +) { + const { className, data, askingStreamTask, minimized, preparedTask } = props; + const { view, sql } = data; const processState = useMemo( - () => convertAskingTaskToProcessState(askingTask), - [askingTask], + () => convertAskingTaskToProcessState(preparedTask), + [preparedTask], ); const isFixedSQL = useMemo(() => { - return sql && askingTask?.invalidSql; - }, [sql, askingTask?.invalidSql]); + return sql && preparedTask?.invalidSql; + }, [sql, preparedTask?.invalidSql]); // displays const showView = !!view; - const showSqlPair = !!askingTask?.candidates[0]?.sqlPair; + const showSqlPair = !!preparedTask?.candidates[0]?.sqlPair; const showRetrieving = retrievingNextStates.includes(processState); const showOrganizing = organizingNextStates.includes(processState); const showGenerating = generatingNextStates.includes(processState); // data - const retrievedTables = askingTask.retrievedTables || []; + const retrievedTables = preparedTask?.retrievedTables || []; const sqlGenerationReasoning = - askingTask.sqlGenerationReasoning || askingStreamTask || ''; + preparedTask?.sqlGenerationReasoning || askingStreamTask || ''; // loadings const retrieving = processState === PROCESS_STATE.SEARCHING; @@ -88,12 +90,20 @@ export default function PreparationSteps(props: Props) { {showRetrieving && ( - + )} {showOrganizing && ( - + )} {showGenerating && ( diff --git a/wren-ui/src/components/pages/home/preparation/index.tsx b/wren-ui/src/components/pages/home/preparation/index.tsx index 9ec6961c45..9a8b9855b4 100644 --- a/wren-ui/src/components/pages/home/preparation/index.tsx +++ b/wren-ui/src/components/pages/home/preparation/index.tsx @@ -10,6 +10,8 @@ import { IPromptThreadStore } from '@/components/pages/home/promptThread/store'; import { ThreadResponse, AskingTaskStatus, + AskingTask, + AdjustmentTask, } from '@/apollo/client/graphql/__types__'; export type Props = IPromptThreadStore['preparation'] & { @@ -18,29 +20,47 @@ export type Props = IPromptThreadStore['preparation'] & { minimized?: boolean; }; +export type PreparedTask = AskingTask & + AdjustmentTask & { isAdjustment: boolean }; + export default function Preparation(props: Props) { const { className, data, minimized, onFixSQLStatement } = props; - const { askingTask, id: responseId, sql } = data; + const { askingTask, adjustmentTask, adjustment, id: responseId, sql } = data; const [isActive, setIsActive] = useState(!sql); + // Adapt askingTask and adjustmentTask for preparation steps + const preparedTask = useMemo(() => { + if (askingTask === null && adjustmentTask === null) return null; + const { payload } = adjustment || {}; + return { + candidates: [], + invalidSql: '', + retrievedTables: payload?.retrievedTables || [], + sqlGenerationReasoning: payload?.sqlGenerationReasoning || '', + isAdjustment: !!adjustmentTask, + ...(askingTask || {}), + ...(adjustmentTask || {}), + } as PreparedTask; + }, [askingTask?.status, adjustmentTask?.status, adjustment?.payload]); + // wrapping up after answer is prepared useEffect(() => { setIsActive(!minimized); }, [minimized]); const error = useMemo(() => { - return askingTask?.error && !sql + return preparedTask?.error && !sql ? { - ...askingTask.error, - invalidSql: askingTask.invalidSql, + ...preparedTask.error, + invalidSql: preparedTask?.invalidSql, fixStatement: (sql: string) => onFixSQLStatement(responseId, sql), } : null; - }, [askingTask?.error, askingTask?.invalidSql, responseId, sql]); + }, [preparedTask, responseId, sql]); - if (askingTask === null) return null; + if (preparedTask === null) return null; - const isStopped = askingTask.status === AskingTaskStatus.STOPPED; + const isStopped = preparedTask.status === AskingTaskStatus.STOPPED; return (
@@ -73,12 +93,16 @@ export default function Preparation(props: Props) { /> Answer preparation steps - +
} > - + diff --git a/wren-ui/src/components/pages/home/preparation/step/Organizing.tsx b/wren-ui/src/components/pages/home/preparation/step/Organizing.tsx index d811c9fbac..cc11d7d15e 100644 --- a/wren-ui/src/components/pages/home/preparation/step/Organizing.tsx +++ b/wren-ui/src/components/pages/home/preparation/step/Organizing.tsx @@ -6,11 +6,12 @@ import { Spinner } from '@/components/PageLoading'; interface Props { stream: string; loading?: boolean; + isAdjustment?: boolean; } export default function Organizing(props: Props) { const $wrapper = useRef(null); - const { stream, loading } = props; + const { stream, loading, isAdjustment } = props; const isDone = stream && !loading; @@ -30,9 +31,13 @@ export default function Organizing(props: Props) { if (isDone) scrollBottom(); }, [isDone]); + const title = isAdjustment + ? 'User-provided reasoning steps applied' + : 'Organizing thoughts'; + return ( <> - Organizing thoughts + {title}
{ @@ -14,15 +15,23 @@ const TagTemplate = ({ name }: { name: string }) => { const TagIterator = makeIterable(TagTemplate); export default function Retrieving(props: Props) { - const { tables, loading } = props; + const { tables, loading, isAdjustment } = props; const data = tables.map((table) => ({ name: table })); + const title = isAdjustment + ? 'User-selected models applied' + : 'Retrieving related models'; + + const modelDescription = isAdjustment ? ( + <>{tables.length} models applied + ) : ( + <>{tables.length} models found + ); + return ( <> - - Retrieving related models - + {title}
{loading ? (
@@ -31,7 +40,7 @@ export default function Retrieving(props: Props) {
) : ( <> -
{tables.length} models found
+
{modelDescription}
)} diff --git a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx b/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx index 46e6545fab..1fc3d759af 100644 --- a/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx +++ b/wren-ui/src/components/pages/home/promptThread/AnswerResult.tsx @@ -7,6 +7,7 @@ import CheckCircleFilled from '@ant-design/icons/CheckCircleFilled'; import CodeFilled from '@ant-design/icons/CodeFilled'; import PieChartFilled from '@ant-design/icons/PieChartFilled'; import MessageOutlined from '@ant-design/icons/MessageOutlined'; +import ShareAltOutlined from '@ant-design/icons/ShareAltOutlined'; import { RobotSVG } from '@/utils/svgs'; import { ANSWER_TAB_KEYS } from '@/utils/enum'; import { canGenerateAnswer } from '@/hooks/useAskPrompt'; @@ -16,7 +17,7 @@ import RecommendedQuestions, { getRecommendedQuestionProps, } from '@/components/pages/home/RecommendedQuestions'; import ViewBlock from '@/components/pages/home/promptThread/ViewBlock'; -import BreakdownAnswer from '@/components/pages/home/promptThread/BreakdownAnswer'; +import ViewSQLTabContent from '@/components/pages/home/promptThread/ViewSQLTabContent'; import TextBasedAnswer, { getAnswerIsFinished, } from '@/components/pages/home/promptThread/TextBasedAnswer'; @@ -27,10 +28,17 @@ import { ThreadResponse, ThreadResponseAnswerDetail, ThreadResponseAnswerStatus, + ThreadResponseAdjustment, + ThreadResponseAdjustmentType, } from '@/apollo/client/graphql/__types__'; const { Title, Text } = Typography; +const adjustmentType = { + [ThreadResponseAdjustmentType.APPLY_SQL]: 'User-provided SQL applied', + [ThreadResponseAdjustmentType.REASONING]: 'Reasoning steps adjusted', +}; + const knowledgeTooltip = ( <> Store this answer as a Question-SQL pair to help Wren AI improve SQL @@ -146,6 +154,26 @@ const renderRecommendedQuestions = ( ); }; +const AdjustmentInformation = (props: { + adjustment: ThreadResponseAdjustment; +}) => { + const { adjustment } = props; + + return ( +
+
+ +
+ Adjusted answer + + {adjustmentType[adjustment.type]} + +
+
+
+ ); +}; + const isNeedGenerateAnswer = (answerDetail: ThreadResponseAnswerDetail) => { const isFinished = getAnswerIsFinished(answerDetail?.status); // it means the background task has not started yet, but answer is pending for generating @@ -164,7 +192,6 @@ export default function AnswerResult(props: Props) { onOpenSaveAsViewModal, onGenerateThreadRecommendedQuestions, onGenerateTextBasedAnswer, - onGenerateBreakdownAnswer, onGenerateChartAnswer, onOpenSaveToKnowledgeModal, // recommend questions @@ -174,13 +201,24 @@ export default function AnswerResult(props: Props) { preparation, } = usePromptThreadStore(); - const { askingTask, answerDetail, breakdownDetail, id, question, sql, view } = - threadResponse; + const { + askingTask, + adjustmentTask, + answerDetail, + breakdownDetail, + id, + question, + sql, + view, + adjustment, + } = threadResponse; const resultStyle = isLastThreadResponse ? { minHeight: 'calc(100vh - (194px))' } : null; + const isAdjustment = !!adjustment; + const recommendedQuestionProps = getRecommendedQuestionProps( recommendedQuestions, showRecommendedQuestions, @@ -196,7 +234,10 @@ export default function AnswerResult(props: Props) { // initialize generate answer useEffect(() => { if (isBreakdownOnly) return; - if (canGenerateAnswer(askingTask) && isNeedGenerateAnswer(answerDetail)) { + if ( + canGenerateAnswer(askingTask, adjustmentTask) && + isNeedGenerateAnswer(answerDetail) + ) { const debouncedGenerateAnswer = debounce( () => { onGenerateTextBasedAnswer(id); @@ -211,16 +252,14 @@ export default function AnswerResult(props: Props) { debouncedGenerateAnswer.cancel(); }; } - }, [isBreakdownOnly, askingTask?.status, answerDetail?.status]); + }, [ + isBreakdownOnly, + askingTask?.status, + adjustmentTask?.status, + answerDetail?.status, + ]); const onTabClick = (activeKey: string) => { - if ( - activeKey === ANSWER_TAB_KEYS.VIEW_SQL && - !threadResponse.breakdownDetail - ) { - onGenerateBreakdownAnswer(id); - } - if (activeKey === ANSWER_TAB_KEYS.CHART && !threadResponse.chartDetail) { onGenerateChartAnswer(id); } @@ -233,6 +272,7 @@ export default function AnswerResult(props: Props) { return (
+ {isAdjustment && } } > - + - ); - } - - return ( - -
-
{description}
- {(steps || []).map((step, index) => ( - - ))} -
-
- ); -} diff --git a/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx b/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx index 63fca1aad3..adb24bbcda 100644 --- a/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx +++ b/wren-ui/src/components/pages/home/promptThread/ChartAnswer.tsx @@ -218,7 +218,7 @@ export default function ChartAnswer(props: AnswerResultProps) { if (error) { return ( -
+
-
+
{chartDetail?.description} {chartSpec ? ( import('@/components/editor/CodeBlock'), { - ssr: false, -}); - -const { Text } = Typography; - -const StyledToolBar = styled.div` - background-color: var(--gray-2); - height: 32px; - padding: 4px 8px; -`; - -const StyledPre = styled.pre<{ showNativeSQL: boolean }>` - .adm_code-block { - ${(props) => (props.showNativeSQL ? 'border-top: none;' : '')} - } -`; - -interface Props { - isViewSQL?: boolean; - isViewFullSQL?: boolean; - isPreviewData?: boolean; - onCloseCollapse: () => void; - onCopyFullSQL?: () => void; - sql: string; - previewDataResult: ComponentProps; - attributes: { - stepNumber: number; - isLastStep: boolean; - }; - nativeSQLResult: NativeSQLResult; - onChangeNativeSQL: (checked: boolean) => void; -} - -export default function CollapseContent(props: Props) { - const { - isViewSQL, - isViewFullSQL, - isPreviewData, - onCloseCollapse, - onCopyFullSQL, - sql, - previewDataResult, - attributes, - onChangeNativeSQL, - nativeSQLResult, - } = props; - const isStepViewSQL = !isViewFullSQL && isViewSQL; - - const { hasNativeSQL, dataSourceType } = nativeSQLResult; - const showNativeSQL = Boolean(attributes.isLastStep) && hasNativeSQL; - - const sqls = - nativeSQLResult.nativeSQLMode && nativeSQLResult.loading === false - ? nativeSQLResult.data - : sql; - - return ( - <> - {(isViewSQL || isViewFullSQL) && ( - - {showNativeSQL && ( - -
- {nativeSQLResult.nativeSQLMode && ( - <> - {DATA_SOURCE_OPTIONS[dataSourceType].label} - - {DATA_SOURCE_OPTIONS[dataSourceType].label} - - - )} -
-
- } - unCheckedChildren={} - className="mr-2" - size="small" - onChange={onChangeNativeSQL} - loading={nativeSQLResult.loading} - /> - - Show original SQL - -
-
- )} - -
- )} - {isPreviewData && ( -
- - ), - }} - /> -
- )} - {(isStepViewSQL || isPreviewData) && ( -
- - {isPreviewData && ( - Showing up to 500 rows - )} -
- )} - {isViewFullSQL && ( - <> - - - - )} - - ); -} diff --git a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx b/wren-ui/src/components/pages/home/promptThread/StepContent.tsx deleted file mode 100644 index f8eca16ede..0000000000 --- a/wren-ui/src/components/pages/home/promptThread/StepContent.tsx +++ /dev/null @@ -1,104 +0,0 @@ -import { useEffect } from 'react'; -import { Button, ButtonProps, Col, Row, Typography } from 'antd'; -import FunctionOutlined from '@ant-design/icons/FunctionOutlined'; -import { BinocularsIcon } from '@/utils/icons'; -import CollapseContent from '@/components/pages/home/promptThread/CollapseContent'; -import useAnswerStepContent from '@/hooks/useAnswerStepContent'; -import { nextTick } from '@/utils/time'; - -const { Text, Paragraph } = Typography; - -interface Props { - fullSql: string; - isLastStep: boolean; - isLastThreadResponse: boolean; - sql: string; - stepIndex: number; - summary: string; - threadResponseId: number; - onInitPreviewDone: () => void; -} - -export default function StepContent(props: Props) { - const { - fullSql, - isLastStep, - isLastThreadResponse, - sql, - stepIndex, - summary, - threadResponseId, - onInitPreviewDone, - } = props; - - const { collapseContentProps, previewDataButtonProps, viewSQLButtonProps } = - useAnswerStepContent({ - fullSql, - isLastStep, - sql, - threadResponseId, - stepIndex, - }); - - const stepNumber = stepIndex + 1; - - const autoTriggerPreviewDataButton = async () => { - await nextTick(); - await previewDataButtonProps.onClick(); - await nextTick(); - onInitPreviewDone(); - }; - - // when is the last step of the last thread response, auto trigger preview data button - useEffect(() => { - if (isLastStep && isLastThreadResponse) { - autoTriggerPreviewDataButton(); - } - }, [isLastStep, isLastThreadResponse]); - - return ( - - -
{stepNumber}.
- - - - {summary} - - + +
{isStreaming && } {status === ThreadResponseAnswerStatus.INTERRUPTED && ( diff --git a/wren-ui/src/components/pages/home/promptThread/ViewSQLTabContent.tsx b/wren-ui/src/components/pages/home/promptThread/ViewSQLTabContent.tsx new file mode 100644 index 0000000000..fad497068e --- /dev/null +++ b/wren-ui/src/components/pages/home/promptThread/ViewSQLTabContent.tsx @@ -0,0 +1,178 @@ +import dynamic from 'next/dynamic'; +import Image from 'next/image'; +import { useEffect } from 'react'; +import styled from 'styled-components'; +import { Button, Divider, Empty, Space, Switch, Typography } from 'antd'; +import CheckOutlined from '@ant-design/icons/CheckOutlined'; +import CloseOutlined from '@ant-design/icons/CloseOutlined'; +import CodeFilled from '@ant-design/icons/CodeFilled'; +import { BinocularsIcon } from '@/utils/icons'; +import { nextTick } from '@/utils/time'; +import useNativeSQL from '@/hooks/useNativeSQL'; +import { DATA_SOURCE_OPTIONS } from '@/components/pages/setup/utils'; +import { Props as AnswerResultProps } from '@/components/pages/home/promptThread/AnswerResult'; +import usePromptThreadStore from '@/components/pages/home/promptThread/store'; +import PreviewData from '@/components/dataPreview/PreviewData'; +import { usePreviewDataMutation } from '@/apollo/client/graphql/home.generated'; + +const CodeBlock = dynamic(() => import('@/components/editor/CodeBlock'), { + ssr: false, +}); + +const { Text } = Typography; + +const StyledPre = styled.pre` + .adm_code-block { + border-top: none; + border-radius: 0px 0px 4px 4px; + } +`; + +const StyledToolBar = styled.div` + background-color: var(--gray-2); + height: 32px; + padding: 4px 8px; + border: 1px solid var(--gray-3); + border-radius: 4px 4px 0px 0px; +`; + +export default function ViewSQLTabContent(props: AnswerResultProps) { + const { isLastThreadResponse, onInitPreviewDone, threadResponse } = props; + + const { onOpenAdjustSQLModal } = usePromptThreadStore(); + const { fetchNativeSQL, nativeSQLResult } = useNativeSQL(); + const [previewData, previewDataResult] = usePreviewDataMutation({ + onError: (error) => console.error(error), + }); + + const onPreviewData = async () => { + await previewData({ variables: { where: { responseId: id } } }); + }; + + const autoTriggerPreviewDataButton = async () => { + await nextTick(); + await onPreviewData(); + await nextTick(); + onInitPreviewDone(); + }; + + // when is the last step of the last thread response, auto trigger preview data button + useEffect(() => { + if (isLastThreadResponse) { + autoTriggerPreviewDataButton(); + } + }, [isLastThreadResponse]); + + const { id, sql } = threadResponse; + + const { hasNativeSQL, dataSourceType } = nativeSQLResult; + const showNativeSQL = hasNativeSQL; + + const sqls = + nativeSQLResult.nativeSQLMode && nativeSQLResult.loading === false + ? nativeSQLResult.data + : sql; + + const onChangeNativeSQL = async (checked: boolean) => { + nativeSQLResult.setNativeSQLMode(checked); + checked && fetchNativeSQL({ variables: { responseId: id } }); + }; + + return ( +
+ + +
+ {nativeSQLResult.nativeSQLMode && ( + <> + {DATA_SOURCE_OPTIONS[dataSourceType].label} + + {DATA_SOURCE_OPTIONS[dataSourceType].label} + + + )} +
+ }> + {showNativeSQL && ( + + )} + + +
+ +
+
+ + {previewDataResult?.data?.previewData && ( +
+ + ), + }} + /> +
+ Showing up to 500 rows +
+
+ )} +
+
+ ); +} diff --git a/wren-ui/src/components/pages/home/promptThread/index.tsx b/wren-ui/src/components/pages/home/promptThread/index.tsx index 10ecb4c1e4..06df156aaf 100644 --- a/wren-ui/src/components/pages/home/promptThread/index.tsx +++ b/wren-ui/src/components/pages/home/promptThread/index.tsx @@ -75,7 +75,7 @@ export default function PromptThread() { const responses = useMemo(() => data?.responses || [], [data?.responses]); const triggerScrollToBottom = (behavior?: ScrollBehavior) => { - if ((data?.responses || []).length <= 1) return; + if (responses.length <= 1) return; const contentLayout = divRef.current?.parentElement; const allElements = (divRef.current?.querySelectorAll( '[data-jsid="answerResult"]', @@ -98,12 +98,12 @@ export default function PromptThread() { }, [router.query]); useEffect(() => { - const lastResponse = data?.responses[data?.responses.length - 1]; + const lastResponse = responses[responses.length - 1]; const isLastResponseFinished = getIsFinished(lastResponse?.askingTask?.status) || getAnswerIsFinished(lastResponse?.answerDetail?.status); triggerScrollToBottom(isLastResponseFinished ? 'auto' : 'smooth'); - }, [data?.responses]); + }, [responses]); const onInitPreviewDone = () => { triggerScrollToBottom(); diff --git a/wren-ui/src/components/pages/home/promptThread/store.tsx b/wren-ui/src/components/pages/home/promptThread/store.tsx index 994af6d7c9..643babc716 100644 --- a/wren-ui/src/components/pages/home/promptThread/store.tsx +++ b/wren-ui/src/components/pages/home/promptThread/store.tsx @@ -15,7 +15,9 @@ export type IPromptThreadStore = { preparation: { askingStreamTask?: string; onStopAskingTask?: (queryId?: string) => Promise; + onStopAdjustTask?: (queryId?: string) => Promise; onReRunAskingTask?: (threadResponse: ThreadResponse) => Promise; + onReRunAdjustTask?: (threadResponse: ThreadResponse) => Promise; onFixSQLStatement?: (responseId: number, sql: string) => Promise; }; onOpenSaveAsViewModal: (data: { sql: string; responseId: number }) => void; @@ -25,7 +27,6 @@ export type IPromptThreadStore = { }: SelectQuestionProps) => Promise; onGenerateThreadRecommendedQuestions: () => Promise; onGenerateTextBasedAnswer: (responseId: number) => Promise; - onGenerateBreakdownAnswer: (responseId: number) => Promise; onGenerateChartAnswer: (responseId: number) => Promise; onAdjustChartAnswer: ( responseId: number, @@ -35,6 +36,12 @@ export type IPromptThreadStore = { data: { sql: string; question: string }, payload: { isCreateMode: boolean }, ) => void; + onOpenAdjustReasoningStepsModal: (data: { + responseId: number; + retrievedTables: string[]; + sqlGenerationReasoning: string; + }) => void; + onOpenAdjustSQLModal: (data: { responseId: number; sql: string }) => void; }; // Register store provider diff --git a/wren-ui/src/hooks/useAdjustAnswer.tsx b/wren-ui/src/hooks/useAdjustAnswer.tsx new file mode 100644 index 0000000000..2cc859508d --- /dev/null +++ b/wren-ui/src/hooks/useAdjustAnswer.tsx @@ -0,0 +1,159 @@ +import { useEffect, useMemo } from 'react'; +import { cloneDeep } from 'lodash'; +import { ApolloClient, NormalizedCacheObject } from '@apollo/client'; +import { THREAD } from '@/apollo/client/graphql/home'; +import { nextTick } from '@/utils/time'; +import { + useAdjustThreadResponseMutation, + useCancelAdjustmentTaskMutation, + useRerunAdjustmentTaskMutation, + useThreadResponseLazyQuery, +} from '@/apollo/client/graphql/home.generated'; +import { + AskingTaskStatus, + DetailedThread, + ThreadResponse, +} from '@/apollo/client/graphql/__types__'; + +export const getIsFinished = (status: AskingTaskStatus) => + [ + AskingTaskStatus.FINISHED, + AskingTaskStatus.FAILED, + AskingTaskStatus.STOPPED, + ].includes(status); + +const handleUpdateThreadCache = ( + threadId: number, + threadResponse: ThreadResponse, + client: ApolloClient, +) => { + const result = client.cache.readQuery<{ thread: DetailedThread }>({ + query: THREAD, + variables: { threadId }, + }); + + if (result?.thread) { + client.cache.updateQuery( + { + query: THREAD, + variables: { threadId }, + }, + (existingData) => { + const isNewResponse = !existingData.thread.responses + .map((r) => r.id) + .includes(threadResponse.id); + return { + thread: { + ...existingData.thread, + responses: isNewResponse + ? [...existingData.thread.responses, threadResponse] + : existingData.thread.responses.map((response) => { + return response.id === threadResponse.id + ? cloneDeep(threadResponse) + : response; + }), + }, + }; + }, + ); + } +}; + +export default function useAdjustAnswer(threadId?: number) { + const [cancelAdjustmentTask] = useCancelAdjustmentTaskMutation(); + const [rerunAdjustmentTask] = useRerunAdjustmentTaskMutation(); + const [adjustThreadResponse, adjustThreadResponseResult] = + useAdjustThreadResponseMutation(); + const [fetchThreadResponse, threadResponseResult] = + useThreadResponseLazyQuery({ + pollInterval: 1000, + }); + + const loading = adjustThreadResponseResult.loading; + + const adjustmentTask = useMemo(() => { + return threadResponseResult.data?.threadResponse.adjustmentTask || null; + }, [threadResponseResult.data]); + + const data = useMemo(() => { + return { + adjustmentTask, + }; + }, [adjustmentTask]); + + useEffect(() => { + const isFinished = getIsFinished(adjustmentTask?.status); + if (isFinished) threadResponseResult.stopPolling(); + }, [adjustmentTask?.status]); + + const onAdjustReasoningSteps = async ( + responseId: number, + input: { tables: string[]; sqlGenerationReasoning: string }, + ) => { + const response = await adjustThreadResponse({ + variables: { + responseId, + data: { + tables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + }, + }, + }); + + // start polling new thread response + const nextThreadResponse = response.data?.adjustThreadResponse; + await fetchThreadResponse({ + variables: { responseId: nextThreadResponse.id }, + }); + + // update new thread response to cache + handleUpdateThreadCache( + threadId, + nextThreadResponse, + threadResponseResult.client, + ); + }; + + const onAdjustSQL = async (responseId: number, sql: string) => { + const response = await adjustThreadResponse({ + variables: { responseId, data: { sql } }, + }); + + // update thread cache + const nextThreadResponse = response.data?.adjustThreadResponse; + handleUpdateThreadCache( + threadId, + nextThreadResponse, + threadResponseResult.client, + ); + + // It won't have adjusmentTask, no need to fetch + }; + + const onStop = async (queryId?: string) => { + const taskId = + queryId || + adjustThreadResponseResult.data?.adjustThreadResponse?.adjustmentTask + ?.queryId; + if (taskId) { + await cancelAdjustmentTask({ variables: { taskId } }); + // waiting for polling fetching stop + await nextTick(1000); + } + }; + + const onReRun = async (threadResponse: ThreadResponse) => { + const responseId = threadResponse.id; + await rerunAdjustmentTask({ variables: { responseId } }); + await fetchThreadResponse({ variables: { responseId } }); + }; + + return { + data, + loading, + onAdjustReasoningSteps, + onAdjustSQL, + onStop, + onReRun, + }; +} diff --git a/wren-ui/src/hooks/useAnswerStepContent.tsx b/wren-ui/src/hooks/useAnswerStepContent.tsx deleted file mode 100644 index 545f0427b0..0000000000 --- a/wren-ui/src/hooks/useAnswerStepContent.tsx +++ /dev/null @@ -1,133 +0,0 @@ -import { useState } from 'react'; -import copy from 'copy-to-clipboard'; -import { message } from 'antd'; -import { COLLAPSE_CONTENT_TYPE } from '@/utils/enum'; -import useNativeSQL from '@/hooks/useNativeSQL'; -import { usePreviewBreakdownDataMutation } from '@/apollo/client/graphql/home.generated'; - -const getTextButton = (isActive: boolean) => ({ - type: 'text', - className: `d-inline-flex align-center mr-2 ${isActive ? 'gray-9' : 'gray-6'}`, -}); - -function getButtonsProps({ - isLastStep, - isPreviewData, - isViewSQL, - previewDataProps, - onViewSQL, - onPreviewData, -}: { - isLastStep: boolean; - isPreviewData: boolean; - isViewSQL: boolean; - previewDataProps: { loading: boolean }; - onViewSQL: () => void; - onPreviewData: () => Promise; -}) { - const previewDataButtonText = 'Prevew data'; - const viewSQLButtonText = isLastStep ? 'View full SQL' : 'View SQL'; - const previewDataButtonProps = isLastStep - ? { type: 'primary', className: 'mr-2' } - : getTextButton(isPreviewData); - const viewSQLButtonProps = isLastStep ? {} : getTextButton(isViewSQL); - - return { - viewSQLButtonProps: { - ...viewSQLButtonProps, - children: viewSQLButtonText, - onClick: onViewSQL, - }, - previewDataButtonProps: { - ...previewDataButtonProps, - children: previewDataButtonText, - loading: previewDataProps.loading, - onClick: onPreviewData, - }, - }; -} - -export default function useAnswerStepContent({ - fullSql, - isLastStep, - sql, - stepIndex, - threadResponseId, -}: { - fullSql: string; - isLastStep: boolean; - sql: string; - stepIndex: number; - threadResponseId: number; -}) { - const { fetchNativeSQL, nativeSQLResult } = useNativeSQL(); - - const [collapseContentType, setCollapseContentType] = - useState(COLLAPSE_CONTENT_TYPE.NONE); - - const [previewData, previewDataResult] = usePreviewBreakdownDataMutation({ - onError: (error) => console.error(error), - }); - - const onViewSQL = () => - setCollapseContentType(COLLAPSE_CONTENT_TYPE.VIEW_SQL); - - const onPreviewData = async () => { - setCollapseContentType(COLLAPSE_CONTENT_TYPE.PREVIEW_DATA); - nativeSQLResult.setNativeSQLMode(false); - await previewData({ - variables: { where: { responseId: threadResponseId, stepIndex } }, - }); - }; - - const onCloseCollapse = () => { - setCollapseContentType(COLLAPSE_CONTENT_TYPE.NONE); - nativeSQLResult.setNativeSQLMode(false); - }; - - const onCopyFullSQL = () => { - copy(nativeSQLResult.nativeSQLMode ? nativeSQLResult.data : fullSql); - message.success('Copied SQL to clipboard.'); - }; - - const onChangeNativeSQL = async (checked: boolean) => { - nativeSQLResult.setNativeSQLMode(checked); - checked && fetchNativeSQL({ variables: { responseId: threadResponseId } }); - }; - - const isViewSQL = collapseContentType === COLLAPSE_CONTENT_TYPE.VIEW_SQL; - const isPreviewData = - collapseContentType === COLLAPSE_CONTENT_TYPE.PREVIEW_DATA; - const previewDataLoading = previewDataResult.loading; - const answerButtonsProps = getButtonsProps({ - isLastStep, - isPreviewData, - isViewSQL, - onPreviewData, - onViewSQL, - previewDataProps: { - loading: previewDataLoading, - }, - }); - const isViewFullSQL = isLastStep && isViewSQL; - const displayedSQL = isLastStep ? fullSql : sql; - - return { - ...answerButtonsProps, - collapseContentProps: { - isPreviewData, - isViewSQL, - isViewFullSQL, - sql: displayedSQL, - previewDataResult: { - error: previewDataResult.error, - loading: previewDataLoading, - previewData: previewDataResult?.data?.previewBreakdownData, - }, - nativeSQLResult, - onCopyFullSQL, - onCloseCollapse, - onChangeNativeSQL, - }, - }; -} diff --git a/wren-ui/src/hooks/useAskPrompt.tsx b/wren-ui/src/hooks/useAskPrompt.tsx index e9f242874b..3c5d741306 100644 --- a/wren-ui/src/hooks/useAskPrompt.tsx +++ b/wren-ui/src/hooks/useAskPrompt.tsx @@ -1,6 +1,7 @@ import { useCallback, useEffect, useMemo, useState } from 'react'; import { cloneDeep, uniq } from 'lodash'; import { + AdjustmentTask, AskingTask, AskingTaskStatus, AskingTaskType, @@ -36,8 +37,13 @@ export const getIsFinished = (status: AskingTaskStatus) => AskingTaskStatus.STOPPED, ].includes(status); -export const canGenerateAnswer = (askingTask: AskingTask) => - askingTask === null || askingTask?.status === AskingTaskStatus.FINISHED; +export const canGenerateAnswer = ( + askingTask: AskingTask, + adjustmentTask: AdjustmentTask, +) => + (askingTask === null && adjustmentTask === null) || + askingTask?.status === AskingTaskStatus.FINISHED || + adjustmentTask?.status === AskingTaskStatus.FINISHED; export const canFetchThreadResponse = (askingTask: AskingTask) => askingTask !== null && diff --git a/wren-ui/src/hooks/useDropdown.tsx b/wren-ui/src/hooks/useDropdown.tsx new file mode 100644 index 0000000000..24905c1a64 --- /dev/null +++ b/wren-ui/src/hooks/useDropdown.tsx @@ -0,0 +1,15 @@ +import { useState } from 'react'; + +export default function useDropdown() { + const [visible, setVisible] = useState(false); + + const onVisibleChange = (visible: boolean) => setVisible(visible); + + const onCloseDropdownMenu = () => setVisible(false); + + return { + visible, + onVisibleChange, + onCloseDropdownMenu, + }; +} diff --git a/wren-ui/src/hooks/useMentions.tsx b/wren-ui/src/hooks/useMentions.tsx new file mode 100644 index 0000000000..bb47e27abe --- /dev/null +++ b/wren-ui/src/hooks/useMentions.tsx @@ -0,0 +1,62 @@ +import { useMemo } from 'react'; +import { capitalize } from 'lodash'; +import { useDiagramQuery } from '@/apollo/client/graphql/diagram.generated'; +import { getNodeTypeIcon } from '@/utils/nodeType'; +import { + DiagramModel, + DiagramView, + DiagramModelField, + DiagramViewField, +} from '@/apollo/client/graphql/__types__'; + +type Model = DiagramModel | DiagramView; +type Field = DiagramModelField | DiagramViewField; + +interface Props { + skip?: boolean; + includeColumns?: boolean; +} + +const convertMention = (item: (Model | Field) & { meta?: string }) => { + return { + id: `${item.id}-${item.referenceName}`, + label: item.displayName, + value: item.referenceName, + nodeType: capitalize(item.nodeType), + meta: item.meta, + icon: getNodeTypeIcon( + { nodeType: item.nodeType, type: (item as Field).type }, + { className: 'gray-8 mr-2' }, + ), + }; +}; + +export type Mention = ReturnType; + +export default function useMentions(props: Props) { + const { includeColumns, skip } = props; + const { data } = useDiagramQuery({ skip }); + + // handle mentions data + const mentions = useMemo(() => { + const models = data?.diagram.models || []; + const views = data?.diagram.views || []; + + return [...models, ...views].reduce((result, item) => { + result.push(convertMention(item)); + if (includeColumns) { + item.fields.forEach((field) => { + result.push( + convertMention({ + ...field, + meta: `${item.displayName}.${field.displayName}`, + }), + ); + }); + } + return result; + }, [] as Mention[]); + }, [data?.diagram, includeColumns]); + + return { mentions }; +} diff --git a/wren-ui/src/pages/home/[id].tsx b/wren-ui/src/pages/home/[id].tsx index 9a8c3704d6..04995a9919 100644 --- a/wren-ui/src/pages/home/[id].tsx +++ b/wren-ui/src/pages/home/[id].tsx @@ -19,10 +19,13 @@ import useAskPrompt, { canFetchThreadResponse, isRecommendedFinished, } from '@/hooks/useAskPrompt'; +import useAdjustAnswer from '@/hooks/useAdjustAnswer'; import useModalAction from '@/hooks/useModalAction'; import PromptThread from '@/components/pages/home/promptThread'; import SaveAsViewModal from '@/components/modals/SaveAsViewModal'; import QuestionSQLPairModal from '@/components/modals/QuestionSQLPairModal'; +import AdjustReasoningStepsModal from '@/components/modals/AdjustReasoningStepsModal'; +import AdjustSQLModal from '@/components/modals/AdjustSQLModal'; import { getAnswerIsFinished } from '@/components/pages/home/promptThread/TextBasedAnswer'; import { getIsChartFinished } from '@/components/pages/home/promptThread/ChartAnswer'; import { PromptThreadProvider } from '@/components/pages/home/promptThread/store'; @@ -34,7 +37,6 @@ import { useGenerateThreadRecommendationQuestionsMutation, useGetThreadRecommendationQuestionsLazyQuery, useGenerateThreadResponseAnswerMutation, - useGenerateThreadResponseBreakdownMutation, useGenerateThreadResponseChartMutation, useAdjustThreadResponseChartMutation, } from '@/apollo/client/graphql/home.generated'; @@ -54,25 +56,18 @@ const getThreadResponseIsFinished = (threadResponse: ThreadResponse) => { // false make it keep polling when the text based answer is default needed. let isAnswerFinished = isBreakdownOnly ? null : false; - let isBreakdownFinished = null; let isChartFinished = null; // answerDetail status can be FAILED before getting queryId from Wren AI adapter if (answerDetail?.queryId || answerDetail?.status) { isAnswerFinished = getAnswerIsFinished(answerDetail?.status); } - if (breakdownDetail?.queryId) { - isBreakdownFinished = getIsFinished(breakdownDetail?.status); - } + if (chartDetail?.queryId) { isChartFinished = getIsChartFinished(chartDetail?.status); } // if equal false, it means it has task & the task is not finished - return ( - isAnswerFinished !== false && - isBreakdownFinished !== false && - isChartFinished !== false - ); + return isAnswerFinished !== false && isChartFinished !== false; }; export default function HomeThread() { @@ -82,8 +77,11 @@ export default function HomeThread() { const homeSidebar = useHomeSidebar(); const threadId = useMemo(() => Number(params?.id) || null, [params]); const askPrompt = useAskPrompt(threadId); + const adjustAnswer = useAdjustAnswer(threadId); const saveAsViewModal = useModalAction(); const questionSqlPairModal = useModalAction(); + const adjustReasoningStepsModal = useModalAction(); + const adjustSqlModal = useModalAction(); const [showRecommendedQuestions, setShowRecommendedQuestions] = useState(false); @@ -150,9 +148,6 @@ export default function HomeThread() { const [generateThreadResponseAnswer] = useGenerateThreadResponseAnswerMutation(); - const [generateThreadResponseBreakdown] = - useGenerateThreadResponseBreakdownMutation(); - const [generateThreadResponseChart] = useGenerateThreadResponseChartMutation(); const [adjustThreadResponseChart] = useAdjustThreadResponseChartMutation(); @@ -189,13 +184,6 @@ export default function HomeThread() { fetchThreadResponse({ variables: { responseId } }); }; - const onGenerateThreadResponseBreakdown = async (responseId: number) => { - await generateThreadResponseBreakdown({ - variables: { responseId }, - }); - fetchThreadResponse({ variables: { responseId } }); - }; - const onGenerateThreadResponseChart = async (responseId: number) => { await generateThreadResponseChart({ variables: { responseId } }); fetchThreadResponse({ variables: { responseId } }); @@ -317,16 +305,19 @@ export default function HomeThread() { askingStreamTask: askPrompt.data?.askingStreamTask, onStopAskingTask: askPrompt.onStop, onReRunAskingTask: askPrompt.onReRun, + onStopAdjustTask: adjustAnswer.onStop, + onReRunAdjustTask: adjustAnswer.onReRun, onFixSQLStatement, }, onOpenSaveAsViewModal: saveAsViewModal.openModal, onSelectRecommendedQuestion: onCreateResponse, onGenerateThreadRecommendedQuestions: onGenerateThreadRecommendedQuestions, onGenerateTextBasedAnswer: onGenerateThreadResponseAnswer, - onGenerateBreakdownAnswer: onGenerateThreadResponseBreakdown, onGenerateChartAnswer: onGenerateThreadResponseChart, onAdjustChartAnswer: onAdjustThreadResponseChart, onOpenSaveToKnowledgeModal: questionSqlPairModal.openModal, + onOpenAdjustReasoningStepsModal: adjustReasoningStepsModal.openModal, + onOpenAdjustSQLModal: adjustSqlModal.openModal, }; return ( @@ -359,6 +350,27 @@ export default function HomeThread() { await createSqlPairMutation({ variables: { data } }); }} /> + + { + await adjustAnswer.onAdjustReasoningSteps( + values.responseId, + values.data, + ); + }} + /> + + + await adjustAnswer.onAdjustSQL(values.responseId, values.sql) + } + /> ); } diff --git a/wren-ui/src/utils/enum/dropdown.ts b/wren-ui/src/utils/enum/dropdown.ts new file mode 100644 index 0000000000..bd2956401e --- /dev/null +++ b/wren-ui/src/utils/enum/dropdown.ts @@ -0,0 +1,11 @@ +export enum MORE_ACTION { + EDIT = 'edit', + DELETE = 'delete', + UPDATE_COLUMNS = 'update_columns', + REFRESH = 'refresh', + HIDE_CATEGORY = 'hide_category', + VIEW_SQL_PAIR = 'view_sql_pair', + VIEW_INSTRUCTION = 'view_instruction', + ADJUST_SQL = 'adjust_sql', + ADJUST_STEPS = 'adjust_steps', +} diff --git a/wren-ui/src/utils/enum/index.ts b/wren-ui/src/utils/enum/index.ts index ba70eb1146..7d65d7b4fe 100644 --- a/wren-ui/src/utils/enum/index.ts +++ b/wren-ui/src/utils/enum/index.ts @@ -8,3 +8,4 @@ export * from './diagram'; export * from './home'; export * from './settings'; export * from './knowledge'; +export * from './dropdown'; diff --git a/wren-ui/src/utils/enum/modeling.ts b/wren-ui/src/utils/enum/modeling.ts index ea7d5ff923..dfdd7edf64 100644 --- a/wren-ui/src/utils/enum/modeling.ts +++ b/wren-ui/src/utils/enum/modeling.ts @@ -2,13 +2,3 @@ export { NodeType as NODE_TYPE, RelationType as JOIN_TYPE, } from '@/apollo/client/graphql/__types__'; - -export enum MORE_ACTION { - EDIT = 'edit', - DELETE = 'delete', - UPDATE_COLUMNS = 'update_columns', - REFRESH = 'refresh', - HIDE_CATEGORY = 'hide_category', - VIEW_SQL_PAIR = 'view_sql_pair', - VIEW_INSTRUCTION = 'view_instruction', -} diff --git a/wren-ui/src/utils/error/dictionary.ts b/wren-ui/src/utils/error/dictionary.ts index 7dad67daf5..1c264e299b 100644 --- a/wren-ui/src/utils/error/dictionary.ts +++ b/wren-ui/src/utils/error/dictionary.ts @@ -124,4 +124,13 @@ export const ERROR_TEXTS = { REQUIRED: 'Please input SQL statement.', }, }, + ADJUST_REASONING: { + SELECTED_MODELS: { + REQUIRED: 'Please select at least one model', + }, + STEPS: { + REQUIRED: 'Please input reasoning steps', + MAX_LENGTH: 'Reasoning steps must be 3000 characters or fewer.', + }, + }, }; diff --git a/wren-ui/src/utils/errorHandler.tsx b/wren-ui/src/utils/errorHandler.tsx index a1e692dae4..2c060d75fb 100644 --- a/wren-ui/src/utils/errorHandler.tsx +++ b/wren-ui/src/utils/errorHandler.tsx @@ -129,11 +129,11 @@ class GenerateThreadResponseAnswerErrorHandler extends ErrorHandler { } } -class GenerateThreadResponseBreakdownErrorHandler extends ErrorHandler { +class AdjustThreadResponseErrorHandler extends ErrorHandler { public getErrorMessage(error: GraphQLError) { switch (error.extensions?.code) { default: - return 'Failed to generate thread response breakdown SQL answer.'; + return 'Failed to adjust thread response answer.'; } } } @@ -376,8 +376,8 @@ errorHandlers.set( new GenerateThreadResponseAnswerErrorHandler(), ); errorHandlers.set( - 'GenerateThreadResponseBreakdown', - new GenerateThreadResponseBreakdownErrorHandler(), + 'AdjustThreadResponse', + new AdjustThreadResponseErrorHandler(), ); errorHandlers.set('CreateView', new CreateViewErrorHandler()); diff --git a/wren-ui/src/utils/svgs/EditSVG.tsx b/wren-ui/src/utils/svgs/EditSVG.tsx new file mode 100644 index 0000000000..c8bfb3278a --- /dev/null +++ b/wren-ui/src/utils/svgs/EditSVG.tsx @@ -0,0 +1,29 @@ +export const EditSVG = ({ + fillCurrentColor = true, + className, +}: { + fillCurrentColor?: boolean; + className?: string; +}) => ( + + + + +); diff --git a/wren-ui/src/utils/svgs/index.ts b/wren-ui/src/utils/svgs/index.ts index c9b2bb58b5..927cbdc6a8 100644 --- a/wren-ui/src/utils/svgs/index.ts +++ b/wren-ui/src/utils/svgs/index.ts @@ -1,3 +1,4 @@ export * from './CopilotSVG'; export * from './RobotSVG'; export * from './InstructionsSVG'; +export * from './EditSVG';