Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/**
* @param { import("knex").Knex } knex
* @returns { Promise<void> }
*/
exports.up = function (knex) {
return knex.schema.alterTable('thread_response', (table) => {
table
.jsonb('adjustment')
.nullable()
.comment(
'Adjustment data for thread response, including type and payload',
);
});
};

/**
* @param { import("knex").Knex } knex
* @returns { Promise<void> }
*/
exports.down = function (knex) {
return knex.schema.alterTable('thread_response', (table) => {
table.dropColumn('adjustment');
});
};
55 changes: 54 additions & 1 deletion wren-ui/src/apollo/client/graphql/__types__.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ export type AdjustThreadResponseChartInput = {
yAxis?: InputMaybe<Scalars['String']>;
};

export type AdjustThreadResponseInput = {
sql?: InputMaybe<Scalars['String']>;
sqlGenerationReasoning?: InputMaybe<Scalars['String']>;
tables?: InputMaybe<Array<Scalars['String']>>;
};

export type AdjustmentTask = {
__typename?: 'AdjustmentTask';
error?: Maybe<Error>;
queryId?: Maybe<Scalars['String']>;
sql?: Maybe<Scalars['String']>;
status?: Maybe<AskingTaskStatus>;
traceId?: Maybe<Scalars['String']>;
};

export type AskingTask = {
__typename?: 'AskingTask';
candidates: Array<ResultCandidate>;
Expand Down Expand Up @@ -547,7 +562,9 @@ export type ModelWhereInput = {

export type Mutation = {
__typename?: 'Mutation';
adjustThreadResponse: ThreadResponse;
adjustThreadResponseChart: ThreadResponse;
cancelAdjustmentTask: Scalars['Boolean'];
cancelAskingTask: Scalars['Boolean'];
createAskingTask: Task;
createCalculatedField: Scalars['JSON'];
Expand Down Expand Up @@ -581,6 +598,7 @@ export type Mutation = {
previewModelData: Scalars['JSON'];
previewSql: Scalars['JSON'];
previewViewData: Scalars['JSON'];
rerunAdjustmentTask: Scalars['Boolean'];
rerunAskingTask: Task;
resetCurrentProject: Scalars['Boolean'];
resolveSchemaChange: Scalars['Boolean'];
Expand All @@ -607,12 +625,23 @@ export type Mutation = {
};


export type MutationAdjustThreadResponseArgs = {
data: AdjustThreadResponseInput;
responseId: Scalars['Int'];
};


export type MutationAdjustThreadResponseChartArgs = {
data: AdjustThreadResponseChartInput;
responseId: Scalars['Int'];
};


export type MutationCancelAdjustmentTaskArgs = {
taskId: Scalars['String'];
};


export type MutationCancelAskingTaskArgs = {
taskId: Scalars['String'];
};
Expand Down Expand Up @@ -774,6 +803,11 @@ export type MutationPreviewViewDataArgs = {
};


export type MutationRerunAdjustmentTaskArgs = {
responseId: Scalars['Int'];
};


export type MutationRerunAskingTaskArgs = {
responseId: Scalars['Int'];
};
Expand Down Expand Up @@ -933,7 +967,7 @@ export type PreviewItemSqlInput = {
export type PreviewSqlDataInput = {
dryRun?: InputMaybe<Scalars['Boolean']>;
limit?: InputMaybe<Scalars['Int']>;
projectId?: InputMaybe<Scalars['Int']>;
projectId?: InputMaybe<Scalars['String']>;
sql: Scalars['String'];
};

Expand All @@ -957,6 +991,7 @@ export enum ProjectLanguage {

export type Query = {
__typename?: 'Query';
adjustmentTask?: Maybe<AdjustmentTask>;
askingTask?: Maybe<AskingTask>;
autoGenerateRelation: Array<RecommendRelations>;
dashboardItems: Array<DashboardItem>;
Expand Down Expand Up @@ -985,6 +1020,11 @@ export type Query = {
};


export type QueryAdjustmentTaskArgs = {
taskId: Scalars['String'];
};


export type QueryAskingTaskArgs = {
taskId: Scalars['String'];
};
Expand Down Expand Up @@ -1200,6 +1240,8 @@ export type Thread = {

export type ThreadResponse = {
__typename?: 'ThreadResponse';
adjustment?: Maybe<ThreadResponseAdjustment>;
adjustmentTask?: Maybe<AdjustmentTask>;
answerDetail?: Maybe<ThreadResponseAnswerDetail>;
askingTask?: Maybe<AskingTask>;
breakdownDetail?: Maybe<ThreadResponseBreakdownDetail>;
Expand All @@ -1211,6 +1253,17 @@ export type ThreadResponse = {
view?: Maybe<ViewInfo>;
};

export type ThreadResponseAdjustment = {
__typename?: 'ThreadResponseAdjustment';
payload?: Maybe<Scalars['JSON']>;
type: ThreadResponseAdjustmentType;
};

export enum ThreadResponseAdjustmentType {
APPLY_SQL = 'APPLY_SQL',
REASONING = 'REASONING'
}

export type ThreadResponseAnswerDetail = {
__typename?: 'ThreadResponseAnswerDetail';
content?: Maybe<Scalars['String']>;
Expand Down
241 changes: 192 additions & 49 deletions wren-ui/src/apollo/client/graphql/home.generated.ts

Large diffs are not rendered by default.

66 changes: 56 additions & 10 deletions wren-ui/src/apollo/client/graphql/home.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,26 @@ const COMMON_RESPONSE = gql`
askingTask {
...CommonAskingTask
}
adjustment {
type
payload
}
adjustmentTask {
queryId
status
error {
...CommonError
}
sql
traceId
}
}

${COMMON_BREAKDOWN_DETAIL}
${COMMON_ANSWER_DETAIL}
${COMMON_CHART_DETAIL}
${COMMON_ASKING_TASK}
${COMMON_ERROR}
`;

const COMMON_RECOMMENDED_QUESTIONS_TASK = gql`
Expand Down Expand Up @@ -253,6 +267,19 @@ export const UPDATE_THREAD_RESPONSE = gql`
${COMMON_RESPONSE}
`;

// For adjust reasoning steps or SQL
export const ADJUST_THREAD_RESPONSE = gql`
mutation AdjustThreadResponse(
$responseId: Int!
$data: AdjustThreadResponseInput!
) {
adjustThreadResponse(responseId: $responseId, data: $data) {
...CommonResponse
}
}
${COMMON_RESPONSE}
`;

export const DELETE_THREAD = gql`
mutation DeleteThread($where: ThreadUniqueWhereInput!) {
deleteThread(where: $where)
Expand Down Expand Up @@ -329,16 +356,6 @@ export const GENERATE_THREAD_RECOMMENDATION_QUESTIONS = gql`
}
`;

export const GENERATE_THREAD_RESPONSE_BREAKDOWN = gql`
mutation GenerateThreadResponseBreakdown($responseId: Int!) {
generateThreadResponseBreakdown(responseId: $responseId) {
...CommonResponse
}
}

${COMMON_RESPONSE}
`;

export const GENERATE_THREAD_RESPONSE_ANSWER = gql`
mutation GenerateThreadResponseAnswer($responseId: Int!) {
generateThreadResponseAnswer(responseId: $responseId) {
Expand Down Expand Up @@ -369,3 +386,32 @@ export const ADJUST_THREAD_RESPONSE_CHART = gql`
}
${COMMON_RESPONSE}
`;

export const ADJUSTMENT_TASK = gql`
query AdjustmentTask($taskId: String!) {
adjustmentTask(taskId: $taskId) {
queryId
status
error {
code
shortMessage
message
stacktrace
}
sql
traceId
}
}
`;

export const CANCEL_ADJUSTMENT_TASK = gql`
mutation CancelAdjustmentTask($taskId: String!) {
cancelAdjustmentTask(taskId: $taskId)
}
`;

export const RERUN_ADJUSTMENT_TASK = gql`
mutation RerunAdjustmentTask($responseId: Int!) {
rerunAdjustmentTask(responseId: $responseId)
}
`;
84 changes: 83 additions & 1 deletion wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ import {
GenerateInstructionInput,
InstructionStatus,
InstructionResult,
AskFeedbackInput,
AskFeedbackResult,
AskFeedbackStatus,
} from '@server/models/adaptor';
import { getLogger } from '@server/utils';
import * as Errors from '@server/utils/error';
Expand Down Expand Up @@ -119,6 +122,13 @@ export interface IWrenAIAdaptor {
): Promise<AsyncQueryResponse>;
getInstructionResult(queryId: string): Promise<InstructionResult>;
deleteInstructions(ids: number[], projectId: number): Promise<void>;

/**
* Ask feedback APIs
*/
createAskFeedback(input: AskFeedbackInput): Promise<AsyncQueryResponse>;
getAskFeedbackResult(queryId: string): Promise<AskFeedbackResult>;
cancelAskFeedback(queryId: string): Promise<void>;
}

export class WrenAIAdaptor implements IWrenAIAdaptor {
Expand Down Expand Up @@ -647,6 +657,77 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
}
}

public async createAskFeedback(
input: AskFeedbackInput,
): Promise<AsyncQueryResponse> {
try {
const body = {
tables: input.tables,
sql_generation_reasoning: input.sqlGenerationReasoning,
sql: input.sql,
project_id: input.projectId.toString(),
configurations: input.configurations,
};

const res = await axios.post(
`${this.wrenAIBaseEndpoint}/v1/ask-feedbacks`,
body,
);
return { queryId: res.data.query_id };
} catch (err: any) {
logger.debug(
`Got error when creating ask feedback: ${getAIServiceError(err)}`,
);
throw err;
}
}

public async getAskFeedbackResult(
queryId: string,
): Promise<AskFeedbackResult> {
try {
const res = await axios.get(
`${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`,
);
return this.transformAskFeedbackResult(res.data);
} catch (err: any) {
logger.debug(
`Got error when getting ask feedback result: ${getAIServiceError(err)}`,
);
throw err;
}
}

public async cancelAskFeedback(queryId: string): Promise<void> {
try {
await axios.patch(
`${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`,
{
status: 'stopped',
},
);
} catch (err: any) {
logger.debug(
`Got error when canceling ask feedback: ${getAIServiceError(err)}`,
);
throw err;
}
}

private transformAskFeedbackResult(body: any): AskFeedbackResult {
const { status, error } = this.transformStatusAndError(body);
return {
status: status as AskFeedbackStatus,
error,
response:
body.response?.map((result: any) => ({
sql: result.sql,
type: result.type?.toUpperCase() as AskCandidateType,
})) || [],
traceId: body.trace_id,
};
}

private transformChartAdjustmentInput(input: ChartAdjustmentInput) {
const { query, sql, adjustmentOption, chartSchema, configurations } = input;
return {
Expand Down Expand Up @@ -794,7 +875,8 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
| ChartStatus
| SqlPairStatus
| QuestionsStatus
| InstructionStatus;
| InstructionStatus
| AskFeedbackStatus;
error?: {
code: Errors.GeneralErrorCodes;
message: string;
Expand Down
Loading