Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions wren-ui/migrations/20250102074256_create_sql_pair_table.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/**
* @param { import("knex").Knex } knex
* @returns { Promise<void> }
*/
exports.up = function (knex) {
return knex.schema.createTable('sql_pair', (table) => {
table.increments('id').primary();
table
.integer('project_id')
.notNullable()
.comment('Reference to project.id');
table.text('sql').notNullable();
table.string('question', 1000).notNullable();
table.timestamps(true, true);

table.foreign('project_id').references('id').inTable('project');
});
};

/**
* @param { import("knex").Knex } knex
* @returns { Promise<void> }
*/
exports.down = function (knex) {
return knex.schema.dropTable('sql_pair');
};
131 changes: 130 additions & 1 deletion wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ import {
ChartResult,
ChartStatus,
TextBasedAnswerStatus,
SqlPairResult,
SqlPairStatus,
QuestionInput,
QuestionsResult,
QuestionsStatus,
} from '@server/models/adaptor';
import { getLogger } from '@server/utils';
import * as Errors from '@server/utils/error';
import { SqlPair } from '../repositories';

const logger = getLogger('WrenAIAdaptor');
logger.level = 'debug';
Expand Down Expand Up @@ -87,6 +93,18 @@ export interface IWrenAIAdaptor {
adjustChart(input: ChartAdjustmentInput): Promise<AsyncQueryResponse>;
getChartAdjustmentResult(queryId: string): Promise<ChartResult>;
cancelChartAdjustment(queryId: string): Promise<void>;

/**
* Sql Pair APIs
*/
deploySqlPair(
projectId: number,
sqlPair: { question: string; sql: string },
): Promise<AsyncQueryResponse>;
getSqlPairResult(queryId: string): Promise<SqlPairResult>;
deleteSqlPairs(projectId: number, sqlPairIds: number[]): Promise<void>;
generateQuestions(input: QuestionInput): Promise<AsyncQueryResponse>;
getQuestionsResult(queryId: string): Promise<Partial<QuestionsResult>>;
}

export class WrenAIAdaptor implements IWrenAIAdaptor {
Expand All @@ -95,6 +113,70 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
constructor({ wrenAIBaseEndpoint }: { wrenAIBaseEndpoint: string }) {
this.wrenAIBaseEndpoint = wrenAIBaseEndpoint;
}
public async deploySqlPair(
projectId: number,
sqlPair: Partial<SqlPair>,
): Promise<AsyncQueryResponse> {
try {
const body = {
sql_pairs: [
{
id: `${sqlPair.id}`,
sql: sqlPair.sql,
question: sqlPair.question,
},
],
project_id: projectId.toString(),
};

return axios
.post(`${this.wrenAIBaseEndpoint}/v1/sql-pairs`, body)
.then((res) => {
return { queryId: res.data.event_id };
});
} catch (err: any) {
logger.debug(
`Got error when deploying SQL pair: ${getAIServiceError(err)}`,
);
throw err;
}
}
public async getSqlPairResult(queryId: string): Promise<SqlPairResult> {
try {
const res = await axios.get(
`${this.wrenAIBaseEndpoint}/v1/sql-pairs/${queryId}`,
);
const { status, error } = this.transformStatusAndError(res.data);
return {
status: status as SqlPairStatus,
error,
};
} catch (err: any) {
logger.debug(
`Got error when getting SQL pair result: ${getAIServiceError(err)}`,
);
throw err;
}
}
public async deleteSqlPairs(
projectId: number,
sqlPairIds: number[],
): Promise<void> {
try {
await axios.delete(`${this.wrenAIBaseEndpoint}/v1/sql-pairs`, {
data: {
sql_pair_ids: sqlPairIds.map((id) => id.toString()),
project_id: projectId.toString(),
},
});
return;
} catch (err: any) {
logger.debug(
`Got error when deleting SQL pair: ${getAIServiceError(err)}`,
);
throw err;
}
}

/**
* Ask AI service a question.
Expand Down Expand Up @@ -421,7 +503,49 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
throw err;
}
}
public async generateQuestions(
input: QuestionInput,
): Promise<AsyncQueryResponse> {
try {
const body = {
sqls: input.sqls,
project_id: input.projectId.toString(),
configuration: input.configurations,
};

const res = await axios.post(
`${this.wrenAIBaseEndpoint}/v1/sql-questions`,
body,
);
return { queryId: res.data.query_id };
} catch (err: any) {
logger.debug(
`Got error when generating questions: ${getAIServiceError(err)}`,
);
throw err;
}
}

public async getQuestionsResult(
queryId: string,
): Promise<Partial<QuestionsResult>> {
try {
const res = await axios.get(
`${this.wrenAIBaseEndpoint}/v1/sql-questions/${queryId}`,
);
const { status, error } = this.transformStatusAndError(res.data);
return {
status: status as QuestionsStatus,
error,
questions: res.data.questions || [],
};
} catch (err: any) {
logger.debug(
`Got error when getting questions result: ${getAIServiceError(err)}`,
);
throw err;
}
}
private transformChartAdjustmentInput(input: ChartAdjustmentInput) {
const { query, sql, adjustmentOption, chartSchema, configurations } = input;
return {
Expand Down Expand Up @@ -558,7 +682,12 @@ export class WrenAIAdaptor implements IWrenAIAdaptor {
}

private transformStatusAndError(body: any): {
status: AskResultStatus | TextBasedAnswerStatus | ChartStatus;
status:
| AskResultStatus
| TextBasedAnswerStatus
| ChartStatus
| SqlPairStatus
| QuestionsStatus;
error?: {
code: Errors.GeneralErrorCodes;
message: string;
Expand Down
44 changes: 37 additions & 7 deletions wren-ui/src/apollo/server/models/adaptor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export interface AskHistory {
steps: Array<AskStep>;
}

export interface AskConfigurations {
export interface ProjectConfigurations {
language?: string;
timezone?: { name: string };
}
Expand All @@ -61,7 +61,7 @@ export interface AskInput {
query: string;
deployId: string;
history?: AskHistory;
configurations?: AskConfigurations;
configurations?: ProjectConfigurations;
}

export interface AsyncQueryResponse {
Expand Down Expand Up @@ -102,7 +102,7 @@ export interface AskResponse<R, S> {
export interface AskDetailInput {
query: string;
sql: string;
configurations?: AskConfigurations;
configurations?: ProjectConfigurations;
}

export type AskDetailResult = AskResponse<
Expand Down Expand Up @@ -143,7 +143,7 @@ export type RecommendationQuestionsInput = {
maxCategories?: number;
regenerate?: boolean; // Optional regenerate questions (default: false)
// Optional configuration settings
configuration?: AskConfigurations;
configuration?: ProjectConfigurations;
};

export type RecommendationQuestion = {
Expand All @@ -166,7 +166,7 @@ export interface TextBasedAnswerInput {
sqlData: any;
threadId?: string;
userId?: string;
configurations?: AskConfigurations;
configurations?: ProjectConfigurations;
}

export enum TextBasedAnswerStatus {
Expand Down Expand Up @@ -203,7 +203,7 @@ export interface ChartInput {
query: string;
sql: string;
projectId?: string;
configurations?: AskConfigurations;
configurations?: ProjectConfigurations;
}

export interface ChartAdjustmentOption {
Expand All @@ -221,7 +221,7 @@ export interface ChartAdjustmentInput {
adjustmentOption: ChartAdjustmentOption;
chartSchema: Record<string, any>;
projectId?: string;
configurations?: AskConfigurations;
configurations?: ProjectConfigurations;
}

export interface ChartResponse {
Expand All @@ -235,3 +235,33 @@ export interface ChartResult {
response?: ChartResponse;
error?: WrenAIError;
}

export enum SqlPairStatus {
INDEXING = 'INDEXING',
FINISHED = 'FINISHED',
FAILED = 'FAILED',
}

export interface SqlPairResult {
status: SqlPairStatus;
error?: WrenAIError;
}

export interface QuestionInput {
sqls: string[];
projectId: number;
configurations?: ProjectConfigurations;
}

export enum QuestionsStatus {
GENERATING = 'GENERATING',
SUCCEEDED = 'SUCCEEDED',
FAILED = 'FAILED',
}

export interface QuestionsResult {
status: QuestionsStatus;
error?: WrenAIError;
questions?: string[];
trace_id?: string;
}
1 change: 1 addition & 0 deletions wren-ui/src/apollo/server/repositories/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ export * from './threadResponseRepository';
export * from './schemaChangeRepository';
export * from './dashboardRepository';
export * from './dashboardItemRepository';
export * from './sqlPairRepository';
20 changes: 20 additions & 0 deletions wren-ui/src/apollo/server/repositories/sqlPairRepository.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import { Knex } from 'knex';
import { BaseRepository, IBasicRepository } from './baseRepository';

export interface SqlPair {
id: number; // ID
projectId: number; // Reference to project.id
sql: string; // SQL query
question: string; // Natural language question
}

export interface ISqlPairRepository extends IBasicRepository<SqlPair> {}

export class SqlPairRepository
extends BaseRepository<SqlPair>
implements ISqlPairRepository
{
constructor(knexPg: Knex) {
super({ knexPg, tableName: 'sql_pair' });
}
}
11 changes: 11 additions & 0 deletions wren-ui/src/apollo/server/resolvers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import { AskingResolver } from './resolvers/askingResolver';
import { DiagramResolver } from './resolvers/diagramResolver';
import { LearningResolver } from './resolvers/learningResolver';
import { DashboardResolver } from './resolvers/dashboardResolver';
import { SqlPairResolver } from './resolvers/sqlPairResolver';
import { convertColumnType } from '@server/utils';

const projectResolver = new ProjectResolver();
Expand All @@ -13,6 +14,7 @@ const askingResolver = new AskingResolver();
const diagramResolver = new DiagramResolver();
const learningResolver = new LearningResolver();
const dashboardResolver = new DashboardResolver();
const sqlPairResolver = new SqlPairResolver();
const resolvers = {
JSON: GraphQLJSON,
Query: {
Expand Down Expand Up @@ -55,6 +57,9 @@ const resolvers = {

// Dashboard
dashboardItems: dashboardResolver.getDashboardItems,

// SQL Pairs
sqlPairs: sqlPairResolver.getProjectSqlPairs,
},
Mutation: {
deploy: modelResolver.deploy,
Expand Down Expand Up @@ -137,6 +142,12 @@ const resolvers = {
createDashboardItem: dashboardResolver.createDashboardItem,
deleteDashboardItem: dashboardResolver.deleteDashboardItem,
previewItemSQL: dashboardResolver.previewItemSQL,

// SQL Pairs
createSqlPair: sqlPairResolver.createSqlPair,
updateSqlPair: sqlPairResolver.updateSqlPair,
deleteSqlPair: sqlPairResolver.deleteSqlPair,
generateQuestion: sqlPairResolver.generateQuestion,
},
ThreadResponse: askingResolver.getThreadResponseNestedResolver(),
DetailStep: askingResolver.getDetailStepNestedResolver(),
Expand Down
Loading