diff --git a/.github/workflows/release-rc-pr.yaml b/.github/workflows/create-rc-release-pr.yaml similarity index 53% rename from .github/workflows/release-rc-pr.yaml rename to .github/workflows/create-rc-release-pr.yaml index 9bdb5bee82..b1bb8b68bc 100644 --- a/.github/workflows/release-rc-pr.yaml +++ b/.github/workflows/create-rc-release-pr.yaml @@ -1,4 +1,4 @@ -name: Create Release RC +name: Create RC Release PR on: workflow_dispatch: @@ -32,12 +32,6 @@ jobs: git config --global user.name "wren-ai[bot]" git config --global user.email "dev@cannerdata.com" - - name: Create branch - run: | - BRANCH_NAME="release/${{ github.event.inputs.release_version }}" - git checkout -b $BRANCH_NAME - echo "BRANCH_NAME=$BRANCH_NAME" >> $GITHUB_ENV - - name: Update docker.go run: | FILE_PATH="wren-launcher/utils/docker.go" @@ -70,51 +64,12 @@ jobs: echo "===== Git diff for changed files =====" git diff - - name: Commit changes - run: | - git add wren-launcher/utils/docker.go docker/.env.example - git commit -m "release ${{ github.event.inputs.release_version }}" || echo "No changes to commit" - git status - - - name: Push branch - run: | - git push --set-upstream origin ${{ env.BRANCH_NAME }} - echo "Branch pushed to origin/${{ env.BRANCH_NAME }}" - - - name: Install GitHub CLI - run: | - type -p curl >/dev/null || (apt update && apt install curl -y) - curl -fsSL https://cli.github.com/packages/githubcli-archive-keyring.gpg | dd of=/usr/share/keyrings/githubcli-archive-keyring.gpg \ - && chmod go+r /usr/share/keyrings/githubcli-archive-keyring.gpg \ - && echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/githubcli-archive-keyring.gpg] https://cli.github.com/packages stable main" | tee /etc/apt/sources.list.d/github-cli.list > /dev/null \ - && apt update \ - && apt install gh -y - - - name: Create Pull Request with GitHub CLI - run: | - # Check if branch has changes compared to main - if git diff --quiet origin/main origin/${{ env.BRANCH_NAME }}; then - echo "No changes detected between main and ${{ env.BRANCH_NAME }}. Cannot create PR." - exit 1 - fi - - # Create PR using gh cli - gh pr create \ - --title "Release ${{ github.event.inputs.release_version }}" \ - --body "Release ${{ github.event.inputs.release_version }}" \ - --base main \ - --head ${{ env.BRANCH_NAME }} - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - - name: Fallback PR creation - if: failure() - uses: peter-evans/create-pull-request@v5 + - name: Create PR + uses: peter-evans/create-pull-request@v7 with: - token: ${{ secrets.GITHUB_TOKEN }} - commit-message: "release ${{ github.event.inputs.release_version }}" - branch: ${{ env.BRANCH_NAME }} base: main + branch: "release/${{ github.event.inputs.release_version }}" + commit-message: "release ${{ github.event.inputs.release_version }}" title: "Release ${{ github.event.inputs.release_version }}" body: "Release ${{ github.event.inputs.release_version }}" draft: false diff --git a/.github/workflows/create-rc-release.yaml b/.github/workflows/create-rc-release.yaml new file mode 100644 index 0000000000..6e59aeba3c --- /dev/null +++ b/.github/workflows/create-rc-release.yaml @@ -0,0 +1,81 @@ +name: Create RC Release + +on: + workflow_dispatch: + inputs: + release_version: + description: "Release version" + required: true + issue_comment: + types: [created] + +jobs: + release: + runs-on: macos-latest + if: ${{ github.event_name == 'issue_comment' && contains(github.event.comment.body, '/release-rc') && startsWith(github.event.issue.title, 'Release') }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + ref: main + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: "1.23.0" + + - name: Add rocket emoji to comment + run: | + curl -X POST -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ + -d '{"body": "🚀 Starting the release process!"}' \ + "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/comments" + + - name: Parse release version from PR title + id: parse_release_version + env: + GITHUB_ISSUE_TITLE: ${{ github.event.issue.title }} + run: | + release_version=$(echo "$GITHUB_ISSUE_TITLE" | sed 's/ /\n/g' | tail -n 1) + echo "Release version: $release_version" + echo "release_version=$release_version" >> $GITHUB_OUTPUT + + - name: Build for macOS + working-directory: ./wren-launcher + run: | + mkdir -p dist + env GOARCH=amd64 GOOS=darwin CGO_ENABLED=1 go build -o dist/wren-launcher-darwin main.go + cd dist && chmod +x wren-launcher-darwin && tar zcvf wren-launcher-darwin.tar.gz wren-launcher-darwin + + - name: Build for Linux + working-directory: ./wren-launcher + run: | + mkdir -p dist + env GOARCH=amd64 GOOS=linux CGO_ENABLED=0 go build -o dist/wren-launcher-linux main.go + cd dist && chmod +x wren-launcher-linux && tar zcvf wren-launcher-linux.tar.gz wren-launcher-linux + + - name: Build for Windows + working-directory: ./wren-launcher + run: | + mkdir -p dist + env GOARCH=amd64 GOOS=windows CGO_ENABLED=0 go build -o dist/wren-launcher-windows.exe main.go + cd dist && zip wren-launcher-windows.zip wren-launcher-windows.exe + + - name: Create GitHub Release + uses: softprops/action-gh-release@v2 + with: + files: | + ./wren-launcher/dist/wren-launcher-darwin.tar.gz + ./wren-launcher/dist/wren-launcher-linux.tar.gz + ./wren-launcher/dist/wren-launcher-windows.zip + tag_name: ${{ steps.parse_release_version.outputs.release_version }} + prerelease: true + generate_release_notes: true + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: Comment with release link + run: | + release_url="https://github.com/${{ github.repository }}/releases/tag/${{ steps.parse_release_version.outputs.release_version }}" + curl -X POST -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \ + -d "{\"body\": \"🚀 A new release has been created! [View Release](${release_url})\"}" \ + "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.issue.number }}/comments" diff --git a/.github/workflows/ui-release-image-stable.yaml b/.github/workflows/ui-release-image-stable.yaml index f22014f38d..eebcc50bf3 100644 --- a/.github/workflows/ui-release-image-stable.yaml +++ b/.github/workflows/ui-release-image-stable.yaml @@ -87,8 +87,9 @@ jobs: password: ${{ secrets.GITHUB_TOKEN }} - name: Create manifest list and push working-directory: /tmp/digests + # tag with input version and latest run: | - TAGS=$(echo "${{ steps.meta.outputs.tags }}" | awk '{printf "--tag %s ", $0}') + TAGS=$(echo "${{ steps.meta.outputs.tags }} --tag latest" | awk '{printf "--tag %s ", $0}') docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \ $(printf '${{ env.WREN_UI_IMAGE }}@sha256:%s ' *) \ $TAGS diff --git a/.gitignore b/.gitignore index f62facef6c..741105ee4c 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ wren-ai-service/tools/dev/etc/** .deepeval-cache.json .deepeval_telemtry.txt docker/config.yaml +docker/docker-compose-local.yaml # python .python-version diff --git a/docker/.env.example b/docker/.env.example index 7cc5971dbd..9aa4b0dc01 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -20,11 +20,11 @@ OPENAI_API_KEY= # version # CHANGE THIS TO THE LATEST VERSION -WREN_PRODUCT_VERSION=0.18.0-rc.1 +WREN_PRODUCT_VERSION=0.18.0 WREN_ENGINE_VERSION=0.14.8 -WREN_AI_SERVICE_VERSION=0.18.0 +WREN_AI_SERVICE_VERSION=0.18.2 IBIS_SERVER_VERSION=0.14.8 -WREN_UI_VERSION=0.23.2 +WREN_UI_VERSION=0.23.4 WREN_BOOTSTRAP_VERSION=0.1.5 # user id (uuid v4) diff --git a/wren-ai-service/pyproject.toml b/wren-ai-service/pyproject.toml index 68406cad3e..d37693c88d 100644 --- a/wren-ai-service/pyproject.toml +++ b/wren-ai-service/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wren-ai-service" -version = "0.18.0" +version = "0.18.2" description = "" authors = ["dev@getwren.ai"] license = "AGPL-3.0" diff --git a/wren-ai-service/src/pipelines/indexing/instructions.py b/wren-ai-service/src/pipelines/indexing/instructions.py index bcb417f675..83a3f70707 100644 --- a/wren-ai-service/src/pipelines/indexing/instructions.py +++ b/wren-ai-service/src/pipelines/indexing/instructions.py @@ -43,7 +43,9 @@ def run(self, instructions: list[Instruction], project_id: Optional[str] = ""): "is_default": instruction.is_default, **addition, }, - content=instruction.question, + content="this is a global instruction, so no question is provided" + if instruction.is_default + else instruction.question, ) for instruction in instructions ] @@ -164,12 +166,12 @@ async def run( @observe(name="Clean Documents for Instructions") async def clean( self, - instructions: List[Instruction], + instructions: Optional[List[Instruction]] = None, project_id: Optional[str] = None, delete_all: bool = False, ) -> None: await clean( - instructions=instructions, + instructions=instructions or [], cleaner=self._components["cleaner"], project_id=project_id, delete_all=delete_all, diff --git a/wren-ai-service/src/web/v1/services/instructions.py b/wren-ai-service/src/web/v1/services/instructions.py index 393175f03f..cb6d17442d 100644 --- a/wren-ai-service/src/web/v1/services/instructions.py +++ b/wren-ai-service/src/web/v1/services/instructions.py @@ -74,16 +74,27 @@ async def index( trace_id = kwargs.get("trace_id") try: - instructions = [ - Instruction( - id=instruction.id, - instruction=instruction.instruction, - question=question, - is_default=instruction.is_default, - ) - for instruction in request.instructions - for question in instruction.questions - ] + instructions = [] + for instruction in request.instructions: + if instruction.is_default: + instructions.append( + Instruction( + id=instruction.id, + instruction=instruction.instruction, + question="", + is_default=True, + ) + ) + else: + for question in instruction.questions: + instructions.append( + Instruction( + id=instruction.id, + instruction=instruction.instruction, + question=question, + is_default=False, + ) + ) await self._pipelines["instructions_indexing"].run( project_id=request.project_id, diff --git a/wren-launcher/Makefile b/wren-launcher/Makefile index 284e0c37c9..43738b515e 100644 --- a/wren-launcher/Makefile +++ b/wren-launcher/Makefile @@ -2,8 +2,8 @@ BINARY_NAME=wren-launcher build: env GOARCH=amd64 GOOS=darwin CGO_ENABLED=1 go build -o dist/${BINARY_NAME}-darwin main.go - env GOARCH=amd64 GOOS=linux go build -o dist/${BINARY_NAME}-linux main.go - env GOARCH=amd64 GOOS=windows go build -o dist/${BINARY_NAME}-windows.exe main.go + env GOARCH=amd64 GOOS=linux CGO_ENABLED=0 go build -o dist/${BINARY_NAME}-linux main.go + env GOARCH=amd64 GOOS=windows CGO_ENABLED=0 go build -o dist/${BINARY_NAME}-windows.exe main.go cd ./dist; chmod +x ${BINARY_NAME}-darwin && chmod +x ${BINARY_NAME}-linux \ && tar zcvf ${BINARY_NAME}-darwin.tar.gz ${BINARY_NAME}-darwin \ && tar zcvf ${BINARY_NAME}-linux.tar.gz ${BINARY_NAME}-linux \ diff --git a/wren-launcher/utils/docker.go b/wren-launcher/utils/docker.go index 3fbdc08310..c5931b271b 100644 --- a/wren-launcher/utils/docker.go +++ b/wren-launcher/utils/docker.go @@ -24,7 +24,7 @@ import ( const ( // please change the version when the version is updated - WREN_PRODUCT_VERSION string = "0.18.0-rc.1" + WREN_PRODUCT_VERSION string = "0.18.0" DOCKER_COMPOSE_YAML_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/docker-compose.yaml" DOCKER_COMPOSE_ENV_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/.env.example" AI_SERVICE_CONFIG_URL string = "https://raw.githubusercontent.com/Canner/WrenAI/" + WREN_PRODUCT_VERSION + "/docker/config.example.yaml" diff --git a/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js b/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js new file mode 100644 index 0000000000..5c06d6ddf5 --- /dev/null +++ b/wren-ui/migrations/20250510000000_add_adjustment_to_thread_response.js @@ -0,0 +1,24 @@ +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.up = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table + .jsonb('adjustment') + .nullable() + .comment( + 'Adjustment data for thread response, including type and payload', + ); + }); +}; + +/** + * @param { import("knex").Knex } knex + * @returns { Promise } + */ +exports.down = function (knex) { + return knex.schema.alterTable('thread_response', (table) => { + table.dropColumn('adjustment'); + }); +}; diff --git a/wren-ui/package.json b/wren-ui/package.json index 48893af357..6382387cde 100644 --- a/wren-ui/package.json +++ b/wren-ui/package.json @@ -1,6 +1,6 @@ { "name": "wren-ui", - "version": "0.23.2", + "version": "0.23.3", "private": true, "scripts": { "dev": "next dev", diff --git a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts index f81449223e..e0caa6503b 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenAIAdaptor.ts @@ -30,6 +30,9 @@ import { GenerateInstructionInput, InstructionStatus, InstructionResult, + AskFeedbackInput, + AskFeedbackResult, + AskFeedbackStatus, } from '@server/models/adaptor'; import { getLogger } from '@server/utils'; import * as Errors from '@server/utils/error'; @@ -48,6 +51,7 @@ const getAIServiceError = (error: any) => { export interface IWrenAIAdaptor { deploy(deployData: DeployData): Promise; + delete(projectId: number): Promise; /** * Ask AI service a question. @@ -118,6 +122,13 @@ export interface IWrenAIAdaptor { ): Promise; getInstructionResult(queryId: string): Promise; deleteInstructions(ids: number[], projectId: number): Promise; + + /** + * Ask feedback APIs + */ + createAskFeedback(input: AskFeedbackInput): Promise; + getAskFeedbackResult(queryId: string): Promise; + cancelAskFeedback(queryId: string): Promise; } export class WrenAIAdaptor implements IWrenAIAdaptor { @@ -126,6 +137,31 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { constructor({ wrenAIBaseEndpoint }: { wrenAIBaseEndpoint: string }) { this.wrenAIBaseEndpoint = wrenAIBaseEndpoint; } + + public async delete(projectId: number): Promise { + try { + if (!projectId) { + throw new Error('Project ID is required'); + } + const url = `${this.wrenAIBaseEndpoint}/v1/semantics`; + const response = await axios.delete(url, { + params: { + project_id: projectId.toString(), + }, + }); + + if (response.status === 200) { + logger.info(`Wren AI: Deleted semantics for project ${projectId}`); + } else { + throw new Error(`Failed to delete semantics. ${response.data?.error}`); + } + } catch (error: any) { + throw new Error( + `Wren AI: Failed to delete semantics: ${getAIServiceError(error)}`, + ); + } + } + public async deploySqlPair( projectId: number, sqlPair: Partial, @@ -542,19 +578,19 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { public async generateInstruction( input: GenerateInstructionInput[], ): Promise { - const body = input.map((item) => ({ - id: item.id.toString(), - instruction: item.instruction, - questions: item.questions, - is_default: item.isDefault, - project_id: item.projectId?.toString(), - })); + const body = { + instructions: input.map((item) => ({ + id: item.id.toString(), + instruction: item.instruction, + questions: item.questions, + is_default: item.isDefault, + })), + project_id: input[0]?.projectId.toString(), + }; try { const res = await axios.post( `${this.wrenAIBaseEndpoint}/v1/instructions`, - { - instructions: body, - }, + body, ); return { queryId: res.data.event_id }; } catch (err: any) { @@ -621,6 +657,77 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { } } + public async createAskFeedback( + input: AskFeedbackInput, + ): Promise { + try { + const body = { + tables: input.tables, + sql_generation_reasoning: input.sqlGenerationReasoning, + sql: input.sql, + project_id: input.projectId.toString(), + configurations: input.configurations, + }; + + const res = await axios.post( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks`, + body, + ); + return { queryId: res.data.query_id }; + } catch (err: any) { + logger.debug( + `Got error when creating ask feedback: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async getAskFeedbackResult( + queryId: string, + ): Promise { + try { + const res = await axios.get( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`, + ); + return this.transformAskFeedbackResult(res.data); + } catch (err: any) { + logger.debug( + `Got error when getting ask feedback result: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + public async cancelAskFeedback(queryId: string): Promise { + try { + await axios.patch( + `${this.wrenAIBaseEndpoint}/v1/ask-feedbacks/${queryId}`, + { + status: 'stopped', + }, + ); + } catch (err: any) { + logger.debug( + `Got error when canceling ask feedback: ${getAIServiceError(err)}`, + ); + throw err; + } + } + + private transformAskFeedbackResult(body: any): AskFeedbackResult { + const { status, error } = this.transformStatusAndError(body); + return { + status: status as AskFeedbackStatus, + error, + response: + body.response?.map((result: any) => ({ + sql: result.sql, + type: result.type?.toUpperCase() as AskCandidateType, + })) || [], + traceId: body.trace_id, + }; + } + private transformChartAdjustmentInput(input: ChartAdjustmentInput) { const { query, sql, adjustmentOption, chartSchema, configurations } = input; return { @@ -768,7 +875,8 @@ export class WrenAIAdaptor implements IWrenAIAdaptor { | ChartStatus | SqlPairStatus | QuestionsStatus - | InstructionStatus; + | InstructionStatus + | AskFeedbackStatus; error?: { code: Errors.GeneralErrorCodes; message: string; diff --git a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts index c946c9b7a0..f849cffb88 100644 --- a/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts +++ b/wren-ui/src/apollo/server/adaptors/wrenEngineAdaptor.ts @@ -200,7 +200,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { return res.data as EngineQueryResponse; } catch (err: any) { logger.debug(`Got error when querying duckdb: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } @@ -219,7 +222,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { await axios.patch(url.href, configPayload, { headers }); } catch (err: any) { logger.debug(`Got error when patching config: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } @@ -248,7 +254,10 @@ export class WrenEngineAdaptor implements IWrenEngineAdaptor { return res.data; } catch (err: any) { logger.debug(`Got error when previewing data: ${err.message}`); - throw err; + throw Errors.create(Errors.GeneralErrorCodes.WREN_ENGINE_ERROR, { + customMessage: err.response?.data?.message || err.message, + originalError: err, + }); } } diff --git a/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts b/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts new file mode 100644 index 0000000000..6db81fa906 --- /dev/null +++ b/wren-ui/src/apollo/server/backgrounds/adjustmentBackgroundTracker.ts @@ -0,0 +1,502 @@ +import { getLogger } from '@server/utils'; +import { + AskFeedbackInput, + AskFeedbackResult, + AskFeedbackStatus, +} from '@server/models/adaptor'; +import { + AskingTask, + IAskingTaskRepository, + IThreadResponseRepository, + ThreadResponse, + ThreadResponseAdjustmentType, +} from '@server/repositories'; +import { IWrenAIAdaptor } from '../adaptors'; +import { TelemetryEvent, WrenService } from '../telemetry/telemetry'; +import { PostHogTelemetry } from '../telemetry/telemetry'; + +const logger = getLogger('AdjustmentTaskTracker'); +logger.level = 'debug'; + +interface TrackedTask { + queryId: string; + taskId?: number; + lastPolled: number; + result?: AskFeedbackResult; + isFinalized: boolean; + threadResponseId: number; + question: string; + originalThreadResponseId: number; + rerun?: boolean; + adjustmentPayload?: { + originalThreadResponseId: number; + retrievedTables: string[]; + sqlGenerationReasoning: string; + }; +} + +export type TrackedAdjustmentResult = AskFeedbackResult & { + taskId?: number; + queryId: string; +}; + +export type CreateAdjustmentTaskInput = AskFeedbackInput & { + threadId: number; + question: string; + originalThreadResponseId: number; + configurations: { language: string }; +}; + +export type RerunAdjustmentTaskInput = { + threadResponseId: number; + threadId: number; + projectId: number; + configurations: { language: string }; +}; + +export interface IAdjustmentBackgroundTaskTracker { + createAdjustmentTask( + input: CreateAdjustmentTaskInput, + ): Promise<{ queryId: string }>; + getAdjustmentResult(queryId: string): Promise; + getAdjustmentResultById(id: number): Promise; + cancelAdjustmentTask(queryId: string): Promise; + rerunAdjustmentTask( + input: RerunAdjustmentTaskInput, + ): Promise<{ queryId: string }>; +} + +export class AdjustmentBackgroundTaskTracker + implements IAdjustmentBackgroundTaskTracker +{ + private wrenAIAdaptor: IWrenAIAdaptor; + private askingTaskRepository: IAskingTaskRepository; + private trackedTasks: Map = new Map(); + private trackedTasksById: Map = new Map(); + private pollingInterval: number; + private memoryRetentionTime: number; + private pollingIntervalId: NodeJS.Timeout; + private runningJobs = new Set(); + private threadResponseRepository: IThreadResponseRepository; + private telemetry: PostHogTelemetry; + + constructor({ + telemetry, + wrenAIAdaptor, + askingTaskRepository, + threadResponseRepository, + pollingInterval = 1000, // 1 second + memoryRetentionTime = 5 * 60 * 1000, // 5 minutes + }: { + telemetry: PostHogTelemetry; + wrenAIAdaptor: IWrenAIAdaptor; + askingTaskRepository: IAskingTaskRepository; + threadResponseRepository: IThreadResponseRepository; + pollingInterval?: number; + memoryRetentionTime?: number; + }) { + this.telemetry = telemetry; + this.wrenAIAdaptor = wrenAIAdaptor; + this.askingTaskRepository = askingTaskRepository; + this.threadResponseRepository = threadResponseRepository; + this.pollingInterval = pollingInterval; + this.memoryRetentionTime = memoryRetentionTime; + this.startPolling(); + } + + public async createAdjustmentTask( + input: CreateAdjustmentTaskInput, + ): Promise<{ queryId: string; createdThreadResponse: ThreadResponse }> { + try { + // Call the AI service to create a task + const response = await this.wrenAIAdaptor.createAskFeedback(input); + const queryId = response.queryId; + + // create a new asking task + const createdAskingTask = await this.askingTaskRepository.createOne({ + queryId, + question: input.question, + threadId: input.threadId, + detail: { + adjustment: true, + status: AskFeedbackStatus.UNDERSTANDING, + response: [], + error: null, + }, + }); + + // create a new thread response with adjustment payload + const createdThreadResponse = + await this.threadResponseRepository.createOne({ + question: input.question, + threadId: input.threadId, + askingTaskId: createdAskingTask.id, + adjustment: { + type: ThreadResponseAdjustmentType.REASONING, + payload: { + originalThreadResponseId: input.originalThreadResponseId, + retrievedTables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + }, + }, + }); + + // bind the thread response to the asking task + // todo: it's weird that we need to update the asking task again + // find a better way to do this + await this.askingTaskRepository.updateOne(createdAskingTask.id, { + threadResponseId: createdThreadResponse.id, + }); + + // Start tracking this task + const task = { + queryId, + lastPolled: Date.now(), + isFinalized: false, + originalThreadResponseId: input.originalThreadResponseId, + threadResponseId: createdThreadResponse.id, + question: input.question, + adjustmentPayload: { + originalThreadResponseId: input.originalThreadResponseId, + retrievedTables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + }, + } as TrackedTask; + this.trackedTasks.set(queryId, task); + this.trackedTasksById.set(createdThreadResponse.id, task); + + logger.info(`Created adjustment task with queryId: ${queryId}`); + return { queryId, createdThreadResponse }; + } catch (err) { + logger.error(`Failed to create adjustment task: ${err}`); + throw err; + } + } + + public async rerunAdjustmentTask( + input: RerunAdjustmentTaskInput, + ): Promise<{ queryId: string }> { + const currentThreadResponse = await this.threadResponseRepository.findOneBy( + { + id: input.threadResponseId, + }, + ); + if (!currentThreadResponse) { + throw new Error(`Thread response ${input.threadResponseId} not found`); + } + + const adjustment = currentThreadResponse.adjustment; + if (!adjustment) { + throw new Error( + `Thread response ${input.threadResponseId} has no adjustment`, + ); + } + + const originalThreadResponse = + await this.threadResponseRepository.findOneBy({ + id: adjustment.payload?.originalThreadResponseId, + }); + if (!originalThreadResponse) { + throw new Error( + `Original thread response ${adjustment.payload?.originalThreadResponseId} not found`, + ); + } + + // call createAskFeedback on AI service + const response = await this.wrenAIAdaptor.createAskFeedback({ + ...input, + tables: adjustment.payload?.retrievedTables, + sqlGenerationReasoning: adjustment.payload?.sqlGenerationReasoning, + sql: originalThreadResponse.sql, + }); + const queryId = response.queryId; + + // update asking task with new queryId + await this.askingTaskRepository.updateOne( + currentThreadResponse.askingTaskId, + { + queryId, + + // reset detail + detail: { + adjustment: true, + status: AskFeedbackStatus.UNDERSTANDING, + response: [], + error: null, + }, + }, + ); + + // schedule task + const task = { + queryId, + lastPolled: Date.now(), + isFinalized: false, + originalThreadResponseId: originalThreadResponse.id, + threadResponseId: currentThreadResponse.id, + question: originalThreadResponse.question, + rerun: true, + adjustmentPayload: { + originalThreadResponseId: originalThreadResponse.id, + retrievedTables: adjustment.payload?.retrievedTables, + sqlGenerationReasoning: adjustment.payload?.sqlGenerationReasoning, + }, + } as TrackedTask; + this.trackedTasks.set(queryId, task); + this.trackedTasksById.set(currentThreadResponse.id, task); + + logger.info(`Rerun adjustment task with queryId: ${queryId}`); + return { queryId }; + } + + public async getAdjustmentResult( + queryId: string, + ): Promise { + // Check if we're tracking this task in memory + const trackedTask = this.trackedTasks.get(queryId); + + if (trackedTask && trackedTask.result) { + return { + ...trackedTask.result, + queryId, + taskId: trackedTask.taskId, + }; + } + + // If not in memory or no result yet, check the database + return this.getAdjustmentResultFromDB({ queryId }); + } + + public async getAdjustmentResultById( + id: number, + ): Promise { + const task = this.trackedTasksById.get(id); + if (task) { + return this.getAdjustmentResult(task.queryId); + } + + return this.getAdjustmentResultFromDB({ taskId: id }); + } + + public async cancelAdjustmentTask(queryId: string): Promise { + await this.wrenAIAdaptor.cancelAskFeedback(queryId); + + // telemetry + const eventName = TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_CANCEL; + this.telemetry.sendEvent(eventName, { + queryId, + }); + } + + public stopPolling(): void { + if (this.pollingIntervalId) { + clearInterval(this.pollingIntervalId); + } + } + + private startPolling(): void { + this.pollingIntervalId = setInterval(() => { + this.pollTasks(); + }, this.pollingInterval); + } + + private async pollTasks(): Promise { + const now = Date.now(); + const tasksToRemove: string[] = []; + + // Create an array of job functions + const jobs = Array.from(this.trackedTasks.entries()).map( + ([queryId, task]) => + async () => { + try { + // Skip if the job is already running + if (this.runningJobs.has(queryId)) { + return; + } + + // Skip finalized tasks that have been in memory too long + if ( + task.isFinalized && + now - task.lastPolled > this.memoryRetentionTime + ) { + tasksToRemove.push(queryId); + return; + } + + // Skip finalized tasks + if (task.isFinalized) { + return; + } + + // Mark the job as running + this.runningJobs.add(queryId); + + // Poll for updates + logger.info(`Polling for updates for task ${queryId}`); + const result = + await this.wrenAIAdaptor.getAskFeedbackResult(queryId); + task.lastPolled = now; + + // if result is not changed, we don't need to update the database + if (!this.isResultChanged(task.result, result)) { + this.runningJobs.delete(queryId); + return; + } + + // update task in memory if any change + task.result = result; + + // update the database + logger.info(`Updating task ${queryId} in database`); + await this.updateTaskInDatabase({ queryId }, task); + + // Check if task is now finalized + if (this.isTaskFinalized(result.status)) { + task.isFinalized = true; + // update thread response if threadResponseId is provided + if (task.threadResponseId) { + await this.updateThreadResponseWhenTaskFinalized(task); + } + + // telemetry + const eventName = task.rerun + ? TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_RERUN + : TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE; + const eventProperties = { + taskId: task.taskId, + queryId: task.queryId, + status: result.status, + error: result.error, + adjustmentPayload: task.adjustmentPayload, + }; + if (result.status === AskFeedbackStatus.FINISHED) { + this.telemetry.sendEvent(eventName, eventProperties); + } else { + this.telemetry.sendEvent( + eventName, + eventProperties, + WrenService.AI, + false, + ); + } + + logger.info( + `Task ${queryId} is finalized with status: ${result.status}`, + ); + } + + // Mark the job as finished + this.runningJobs.delete(queryId); + } catch (err) { + this.runningJobs.delete(queryId); + logger.error(err.stack); + throw err; + } + }, + ); + + // Run all jobs in parallel + Promise.allSettled(jobs.map((job) => job())).then((results) => { + // Log any rejected promises + results.forEach((result, index) => { + if (result.status === 'rejected') { + logger.error(`Job ${index} failed: ${result.reason}`); + } + }); + + // Clean up tasks that have been in memory too long + if (tasksToRemove.length > 0) { + logger.info( + `Cleaning up tasks that have been in memory too long. Tasks: ${tasksToRemove.join( + ', ', + )}`, + ); + } + for (const queryId of tasksToRemove) { + this.trackedTasks.delete(queryId); + } + }); + } + + private async updateThreadResponseWhenTaskFinalized( + task: TrackedTask, + ): Promise { + const response = task?.result?.response?.[0]; + if (!response) { + return; + } + await this.threadResponseRepository.updateOne(task.threadResponseId, { + sql: response?.sql, + }); + } + + private async getAdjustmentResultFromDB({ + queryId, + taskId, + }: { + queryId?: string; + taskId?: number; + }): Promise { + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + return null; + } + + return { + ...(taskRecord?.detail as AskFeedbackResult), + queryId: queryId || taskRecord?.queryId, + taskId: taskRecord?.id, + }; + } + + private async updateTaskInDatabase( + filter: { queryId?: string; taskId?: number }, + trackedTask: TrackedTask, + ): Promise { + const { queryId, taskId } = filter; + let taskRecord: AskingTask | null = null; + if (queryId) { + taskRecord = await this.askingTaskRepository.findByQueryId(queryId); + } else if (taskId) { + taskRecord = await this.askingTaskRepository.findOneBy({ id: taskId }); + } + + if (!taskRecord) { + throw new Error('Asking task not found'); + } + + // update the task + await this.askingTaskRepository.updateOne(taskRecord.id, { + detail: { + adjustment: true, + ...trackedTask.result, + }, + }); + } + + private isTaskFinalized(status: AskFeedbackStatus): boolean { + return [ + AskFeedbackStatus.FINISHED, + AskFeedbackStatus.FAILED, + AskFeedbackStatus.STOPPED, + ].includes(status); + } + + private isResultChanged( + previousResult: AskFeedbackResult, + newResult: AskFeedbackResult, + ): boolean { + // check status change + if (previousResult?.status !== newResult.status) { + return true; + } + + return false; + } +} diff --git a/wren-ui/src/apollo/server/backgrounds/index.ts b/wren-ui/src/apollo/server/backgrounds/index.ts index 3b3438be62..d6e6ef8131 100644 --- a/wren-ui/src/apollo/server/backgrounds/index.ts +++ b/wren-ui/src/apollo/server/backgrounds/index.ts @@ -1,2 +1,3 @@ export * from './recommend-question'; export * from './chart'; +export * from './adjustmentBackgroundTracker'; diff --git a/wren-ui/src/apollo/server/models/adaptor.ts b/wren-ui/src/apollo/server/models/adaptor.ts index b75efd649f..0af534fb53 100644 --- a/wren-ui/src/apollo/server/models/adaptor.ts +++ b/wren-ui/src/apollo/server/models/adaptor.ts @@ -291,3 +291,31 @@ export interface InstructionResult { status: InstructionStatus; error?: WrenAIError; } + +// ask feedback +export interface AskFeedbackInput { + tables: string[]; + sqlGenerationReasoning: string; + sql: string; + projectId: number; + configurations?: ProjectConfigurations; +} + +export enum AskFeedbackStatus { + UNDERSTANDING = 'UNDERSTANDING', + GENERATING = 'GENERATING', + CORRECTING = 'CORRECTING', + FINISHED = 'FINISHED', + FAILED = 'FAILED', + STOPPED = 'STOPPED', +} + +export interface AskFeedbackResult { + status: AskFeedbackStatus; + error?: WrenAIError; + response: Array<{ + type: AskCandidateType.LLM; + sql: string; + }>; + traceId?: string; +} diff --git a/wren-ui/src/apollo/server/models/model.ts b/wren-ui/src/apollo/server/models/model.ts index afb1029f79..fd763d3de6 100644 --- a/wren-ui/src/apollo/server/models/model.ts +++ b/wren-ui/src/apollo/server/models/model.ts @@ -91,7 +91,7 @@ export interface CheckCalculatedFieldCanQueryData { export interface PreviewSQLData { sql: string; - projectId?: number; + projectId?: string; limit?: number; dryRun?: boolean; } diff --git a/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts index 889bbcf5d0..942f445ee1 100644 --- a/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts +++ b/wren-ui/src/apollo/server/repositories/askingTaskRepository.ts @@ -7,13 +7,19 @@ import { mapValues, snakeCase, } from 'lodash'; -import { AskResult } from '../models/adaptor'; +import { AskFeedbackResult, AskResult } from '../models/adaptor'; + +export type AskingTaskDetail = + | AskResult + | (AskFeedbackResult & { + adjustment?: boolean; + }); export interface AskingTask { id: number; queryId: string; question?: string; - detail?: AskResult; + detail?: AskingTaskDetail; threadId?: number; threadResponseId?: number; createdAt: Date; diff --git a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts index 87634989ac..f51d1eda91 100644 --- a/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts +++ b/wren-ui/src/apollo/server/repositories/threadResponseRepository.ts @@ -39,6 +39,29 @@ export interface ThreadResponseChartDetail { adjustment?: boolean; } +export enum ThreadResponseAdjustmentType { + REASONING = 'REASONING', + APPLY_SQL = 'APPLY_SQL', +} + +export type ThreadResponseAdjustmentReasoningPayload = { + originalThreadResponseId?: number; + retrievedTables?: string[]; + sqlGenerationReasoning?: string; +}; + +export type ThreadResponseAdjustmentApplySqlPayload = { + originalThreadResponseId?: number; + sql?: string; +}; + +export interface ThreadResponseAdjustment { + type: ThreadResponseAdjustmentType; + // todo: I think we could use a better way to do this instead of using a union type + payload: ThreadResponseAdjustmentReasoningPayload & + ThreadResponseAdjustmentApplySqlPayload; +} + export interface ThreadResponse { id: number; // ID askingTaskId?: number; // Reference to asking_task.id @@ -49,6 +72,7 @@ export interface ThreadResponse { answerDetail?: ThreadResponseAnswerDetail; // AI generated text-based answer detail breakdownDetail?: ThreadResponseBreakdownDetail; // Thread response breakdown detail chartDetail?: ThreadResponseChartDetail; // Thread response chart detail + adjustment?: ThreadResponseAdjustment; // Thread response adjustment } export interface IThreadResponseRepository @@ -67,6 +91,7 @@ export class ThreadResponseRepository 'answerDetail', 'breakdownDetail', 'chartDetail', + 'adjustment', ]; constructor(knexPg: Knex) { @@ -102,11 +127,16 @@ export class ThreadResponseRepository res.chartDetail && typeof res.chartDetail === 'string' ? JSON.parse(res.chartDetail) : res.chartDetail; + const adjustment = + res.adjustment && typeof res.adjustment === 'string' + ? JSON.parse(res.adjustment) + : res.adjustment; return { ...res, answerDetail: answerDetail || null, breakdownDetail: breakdownDetail || null, chartDetail: chartDetail || null, + adjustment: adjustment || null, }; }) as ThreadResponse[]; } @@ -120,6 +150,7 @@ export class ThreadResponseRepository answerDetail: ThreadResponseAnswerDetail; breakdownDetail: ThreadResponseBreakdownDetail; chartDetail: ThreadResponseChartDetail; + adjustment: ThreadResponseAdjustment; }>, queryOptions?: IQueryOptions, ) { @@ -136,6 +167,7 @@ export class ThreadResponseRepository chartDetail: data.chartDetail ? JSON.stringify(data.chartDetail) : undefined, + adjustment: data.adjustment ? JSON.stringify(data.adjustment) : undefined, }; const executer = queryOptions?.tx ? queryOptions.tx : this.knex; const [result] = await executer(this.tableName) diff --git a/wren-ui/src/apollo/server/resolvers.ts b/wren-ui/src/apollo/server/resolvers.ts index 2e91bf22ac..fc01d77279 100644 --- a/wren-ui/src/apollo/server/resolvers.ts +++ b/wren-ui/src/apollo/server/resolvers.ts @@ -34,6 +34,9 @@ const resolvers = { suggestedQuestions: askingResolver.getSuggestedQuestions, instantRecommendedQuestions: askingResolver.getInstantRecommendedQuestions, + // Adjustment + adjustmentTask: askingResolver.getAdjustmentTask, + // Thread thread: askingResolver.getThread, threads: askingResolver.listThreads, @@ -97,6 +100,11 @@ const resolvers = { askingResolver.createInstantRecommendedQuestions, rerunAskingTask: askingResolver.rerunAskingTask, + // Adjustment + adjustThreadResponse: askingResolver.adjustThreadResponse, + cancelAdjustmentTask: askingResolver.cancelAdjustThreadResponseAnswer, + rerunAdjustmentTask: askingResolver.rerunAdjustThreadResponseAnswer, + // Thread createThread: askingResolver.createThread, updateThread: askingResolver.updateThread, diff --git a/wren-ui/src/apollo/server/resolvers/askingResolver.ts b/wren-ui/src/apollo/server/resolvers/askingResolver.ts index 31e499ec43..e67fb451b3 100644 --- a/wren-ui/src/apollo/server/resolvers/askingResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/askingResolver.ts @@ -5,6 +5,7 @@ import { AskResultType, RecommendationQuestionStatus, ChartAdjustmentOption, + AskFeedbackStatus, } from '@server/models/adaptor'; import { Thread } from '../repositories/threadRepository'; import { @@ -39,6 +40,14 @@ export interface Task { id: string; } +export interface AdjustmentTask { + queryId: string; + status: AskFeedbackStatus; + error: WrenAIError | null; + sql: string; + traceId: string; +} + export interface AskingTask { type: AskResultType | null; status: AskResultStatus; @@ -108,6 +117,13 @@ export class AskingResolver { this.generateThreadResponseChart.bind(this); this.adjustThreadResponseChart = this.adjustThreadResponseChart.bind(this); this.transformAskingTask = this.transformAskingTask.bind(this); + + this.adjustThreadResponse = this.adjustThreadResponse.bind(this); + this.cancelAdjustThreadResponseAnswer = + this.cancelAdjustThreadResponseAnswer.bind(this); + this.rerunAdjustThreadResponseAnswer = + this.rerunAdjustThreadResponseAnswer.bind(this); + this.getAdjustmentTask = this.getAdjustmentTask.bind(this); } public async generateProjectRecommendationQuestions( @@ -304,6 +320,7 @@ export class AskingResolver { breakdownDetail: response.breakdownDetail, answerDetail: response.answerDetail, chartDetail: response.chartDetail, + adjustment: response.adjustment, }); return acc; @@ -448,6 +465,98 @@ export class AskingResolver { return task; } + public async adjustThreadResponse( + _root: any, + args: { + responseId: number; + data: { + tables?: string[]; + sqlGenerationReasoning?: string; + sql?: string; + }; + }, + ctx: IContext, + ): Promise { + const { responseId, data } = args; + const askingService = ctx.askingService; + const project = await ctx.projectService.getCurrentProject(); + + if (data.sql) { + const response = await askingService.adjustThreadResponseWithSQL( + responseId, + { + sql: data.sql, + }, + ); + ctx.telemetry.sendEvent( + TelemetryEvent.HOME_ADJUST_THREAD_RESPONSE_WITH_SQL, + { + sql: data.sql, + responseId, + }, + ); + return response; + } + + return askingService.adjustThreadResponseAnswer( + responseId, + { + projectId: project.id, + tables: data.tables, + sqlGenerationReasoning: data.sqlGenerationReasoning, + }, + { + language: WrenAILanguage[project.language] || WrenAILanguage.EN, + }, + ); + } + + public async cancelAdjustThreadResponseAnswer( + _root: any, + args: { taskId: string }, + ctx: IContext, + ): Promise { + const { taskId } = args; + const askingService = ctx.askingService; + await askingService.cancelAdjustThreadResponseAnswer(taskId); + return true; + } + + public async rerunAdjustThreadResponseAnswer( + _root: any, + args: { responseId: number }, + ctx: IContext, + ): Promise { + const { responseId } = args; + const askingService = ctx.askingService; + const project = await ctx.projectService.getCurrentProject(); + await askingService.rerunAdjustThreadResponseAnswer( + responseId, + project.id, + { + language: WrenAILanguage[project.language] || WrenAILanguage.EN, + }, + ); + return true; + } + + public async getAdjustmentTask( + _root: any, + args: { taskId: string }, + ctx: IContext, + ): Promise { + const { taskId } = args; + const askingService = ctx.askingService; + const adjustmentTask = await askingService.getAdjustmentTask(taskId); + return { + queryId: adjustmentTask?.queryId, + status: adjustmentTask?.status, + error: adjustmentTask?.error, + sql: adjustmentTask?.response?.[0]?.sql, + traceId: adjustmentTask?.traceId, + }; + } + public async generateThreadResponseBreakdown( _root: any, args: { responseId: number }, @@ -604,6 +713,9 @@ export class AskingResolver { return parent.sql ? format(parent.sql) : null; }, askingTask: async (parent: ThreadResponse, _args: any, ctx: IContext) => { + if (parent.adjustment) { + return null; + } const askingService = ctx.askingService; const askingTask = await askingService.getAskingTaskById( parent.askingTaskId, @@ -611,6 +723,27 @@ export class AskingResolver { if (!askingTask) return null; return this.transformAskingTask(askingTask, ctx); }, + adjustmentTask: async ( + parent: ThreadResponse, + _args: any, + ctx: IContext, + ): Promise => { + if (!parent.adjustment) { + return null; + } + const askingService = ctx.askingService; + const adjustmentTask = await askingService.getAdjustmentTaskById( + parent.askingTaskId, + ); + if (!adjustmentTask) return null; + return { + queryId: adjustmentTask?.queryId, + status: adjustmentTask?.status, + error: adjustmentTask?.error, + sql: adjustmentTask?.response?.[0]?.sql, + traceId: adjustmentTask?.traceId, + }; + }, }); public getDetailStepNestedResolver = () => ({ diff --git a/wren-ui/src/apollo/server/resolvers/modelResolver.ts b/wren-ui/src/apollo/server/resolvers/modelResolver.ts index 5711d70765..c80b511690 100644 --- a/wren-ui/src/apollo/server/resolvers/modelResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/modelResolver.ts @@ -932,7 +932,7 @@ export class ModelResolver { ) { const { sql, projectId, limit, dryRun } = args.data; const project = projectId - ? await ctx.projectService.getProjectById(projectId) + ? await ctx.projectService.getProjectById(parseInt(projectId)) : await ctx.projectService.getCurrentProject(); const { manifest } = await ctx.deployService.getLastDeployment(project.id); return await ctx.queryService.preview(sql, { diff --git a/wren-ui/src/apollo/server/resolvers/projectResolver.ts b/wren-ui/src/apollo/server/resolvers/projectResolver.ts index e07bdb6a8e..c8e6022989 100644 --- a/wren-ui/src/apollo/server/resolvers/projectResolver.ts +++ b/wren-ui/src/apollo/server/resolvers/projectResolver.ts @@ -130,6 +130,7 @@ export class ProjectResolver { await ctx.modelService.deleteAllViewsByProjectId(id); await ctx.modelService.deleteAllModelsByProjectId(id); await ctx.projectService.deleteProject(id); + await ctx.wrenAIAdaptor.delete(id); // telemetry ctx.telemetry.sendEvent(eventName, { diff --git a/wren-ui/src/apollo/server/schema.ts b/wren-ui/src/apollo/server/schema.ts index 5e9012723d..2a7835d4f2 100644 --- a/wren-ui/src/apollo/server/schema.ts +++ b/wren-ui/src/apollo/server/schema.ts @@ -656,6 +656,12 @@ export const typeDefs = gql` theta: String } + input AdjustThreadResponseInput { + tables: [String!] + sqlGenerationReasoning: String + sql: String + } + input PreviewDataInput { responseId: Int! # Optional, only used for preview data of a single step @@ -707,6 +713,24 @@ export const typeDefs = gql` adjustment: Boolean } + enum ThreadResponseAdjustmentType { + REASONING + APPLY_SQL + } + + type ThreadResponseAdjustment { + type: ThreadResponseAdjustmentType! + payload: JSON + } + + type AdjustmentTask { + queryId: String + status: AskingTaskStatus + error: Error + sql: String + traceId: String + } + type ThreadResponse { id: Int! threadId: Int! @@ -717,6 +741,8 @@ export const typeDefs = gql` answerDetail: ThreadResponseAnswerDetail chartDetail: ThreadResponseChartDetail askingTask: AskingTask + adjustment: ThreadResponseAdjustment + adjustmentTask: AdjustmentTask } # Thread only consists of basic information of a thread @@ -762,7 +788,7 @@ export const typeDefs = gql` input PreviewSQLDataInput { sql: String! - projectId: Int + projectId: String limit: Int dryRun: Boolean } @@ -955,6 +981,9 @@ export const typeDefs = gql` threadResponse(responseId: Int!): ThreadResponse! nativeSql(responseId: Int!): String! + # Adjustment + adjustmentTask(taskId: String!): AdjustmentTask + # Settings settings: Settings! @@ -1066,6 +1095,14 @@ export const typeDefs = gql` data: AdjustThreadResponseChartInput! ): ThreadResponse! + # Adjustment + adjustThreadResponse( + responseId: Int! + data: AdjustThreadResponseInput! + ): ThreadResponse! + cancelAdjustmentTask(taskId: String!): Boolean! + rerunAdjustmentTask(responseId: Int!): Boolean! + # Settings resetCurrentProject: Boolean! updateCurrentProject(data: UpdateCurrentProjectInput!): Boolean! diff --git a/wren-ui/src/apollo/server/services/askingService.ts b/wren-ui/src/apollo/server/services/askingService.ts index 78eed90901..055caec9e0 100644 --- a/wren-ui/src/apollo/server/services/askingService.ts +++ b/wren-ui/src/apollo/server/services/askingService.ts @@ -16,6 +16,7 @@ import { IThreadRepository, Thread } from '../repositories/threadRepository'; import { IThreadResponseRepository, ThreadResponse, + ThreadResponseAdjustmentType, } from '../repositories/threadResponseRepository'; import { getLogger } from '@server/utils'; import { isEmpty, isNil } from 'lodash'; @@ -25,13 +26,19 @@ import { TelemetryEvent, WrenService, } from '../telemetry/telemetry'; -import { IViewRepository, Project } from '../repositories'; +import { + IAskingTaskRepository, + IViewRepository, + Project, +} from '../repositories'; import { IQueryService, PreviewDataResponse } from './queryService'; import { IMDLService } from './mdlService'; import { ThreadRecommendQuestionBackgroundTracker, ChartBackgroundTracker, ChartAdjustmentBackgroundTracker, + AdjustmentBackgroundTaskTracker, + TrackedAdjustmentResult, } from '../backgrounds'; import { getConfig } from '@server/config'; import { TextBasedAnswerBackgroundTracker } from '../backgrounds/textBasedAnswerBackgroundTracker'; @@ -94,6 +101,17 @@ export enum ThreadResponseAnswerStatus { INTERRUPTED = 'INTERRUPTED', } +// adjustment input +export interface AdjustmentReasoningInput { + tables: string[]; + sqlGenerationReasoning: string; + projectId: number; +} + +export interface AdjustmentSqlInput { + sql: string; +} + export interface IAskingService { /** * Asking task. @@ -155,6 +173,23 @@ export interface IAskingService { input: ChartAdjustmentOption, configurations: { language: string }, ): Promise; + adjustThreadResponseWithSQL( + threadResponseId: number, + input: AdjustmentSqlInput, + ): Promise; + adjustThreadResponseAnswer( + threadResponseId: number, + input: AdjustmentReasoningInput, + configurations: { language: string }, + ): Promise; + cancelAdjustThreadResponseAnswer(taskId: string): Promise; + rerunAdjustThreadResponseAnswer( + threadResponseId: number, + projectId: number, + configurations: { language: string }, + ): Promise<{ queryId: string }>; + getAdjustmentTask(taskId: string): Promise; + getAdjustmentTaskById(id: number): Promise; changeThreadResponseAnswerDetailStatus( responseId: number, status: ThreadResponseAnswerStatus, @@ -379,6 +414,8 @@ export class AskingService implements IAskingService { private telemetry: PostHogTelemetry; private mdlService: IMDLService; private askingTaskTracker: IAskingTaskTracker; + private askingTaskRepository: IAskingTaskRepository; + private adjustmentBackgroundTracker: AdjustmentBackgroundTaskTracker; constructor({ telemetry, @@ -388,6 +425,7 @@ export class AskingService implements IAskingService { viewRepository, threadRepository, threadResponseRepository, + askingTaskRepository, queryService, mdlService, askingTaskTracker, @@ -399,6 +437,7 @@ export class AskingService implements IAskingService { viewRepository: IViewRepository; threadRepository: IThreadRepository; threadResponseRepository: IThreadResponseRepository; + askingTaskRepository: IAskingTaskRepository; queryService: IQueryService; mdlService: IMDLService; askingTaskTracker: IAskingTaskTracker; @@ -441,7 +480,14 @@ export class AskingService implements IAskingService { wrenAIAdaptor, threadRepository, }); + this.adjustmentBackgroundTracker = new AdjustmentBackgroundTaskTracker({ + telemetry, + wrenAIAdaptor, + askingTaskRepository, + threadResponseRepository, + }); + this.askingTaskRepository = askingTaskRepository; this.mdlService = mdlService; this.askingTaskTracker = askingTaskTracker; } @@ -1013,6 +1059,97 @@ export class AskingService implements IAskingService { return lastDeploy.hash; } + public async adjustThreadResponseWithSQL( + threadResponseId: number, + input: AdjustmentSqlInput, + ): Promise { + const response = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!response) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + return await this.threadResponseRepository.createOne({ + sql: input.sql, + threadId: response.threadId, + question: response.question, + adjustment: { + type: ThreadResponseAdjustmentType.APPLY_SQL, + payload: { + originalThreadResponseId: response.id, + sql: input.sql, + }, + }, + }); + } + + public async adjustThreadResponseAnswer( + threadResponseId: number, + input: AdjustmentReasoningInput, + configurations: { language: string }, + ): Promise { + const originalThreadResponse = + await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!originalThreadResponse) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + const { createdThreadResponse } = + await this.adjustmentBackgroundTracker.createAdjustmentTask({ + threadId: originalThreadResponse.threadId, + tables: input.tables, + sqlGenerationReasoning: input.sqlGenerationReasoning, + sql: originalThreadResponse.sql, + projectId: input.projectId, + configurations, + question: originalThreadResponse.question, + originalThreadResponseId: originalThreadResponse.id, + }); + return createdThreadResponse; + } + + public async cancelAdjustThreadResponseAnswer(taskId: string): Promise { + // call cancelAskFeedback on AI service + await this.adjustmentBackgroundTracker.cancelAdjustmentTask(taskId); + } + + public async rerunAdjustThreadResponseAnswer( + threadResponseId: number, + projectId: number, + configurations: { language: string }, + ): Promise<{ queryId: string }> { + const threadResponse = await this.threadResponseRepository.findOneBy({ + id: threadResponseId, + }); + if (!threadResponse) { + throw new Error(`Thread response ${threadResponseId} not found`); + } + + const { queryId } = + await this.adjustmentBackgroundTracker.rerunAdjustmentTask({ + threadId: threadResponse.threadId, + threadResponseId, + projectId, + configurations, + }); + return { queryId }; + } + + public async getAdjustmentTask( + taskId: string, + ): Promise { + return this.adjustmentBackgroundTracker.getAdjustmentResult(taskId); + } + + public async getAdjustmentTaskById( + id: number, + ): Promise { + return this.adjustmentBackgroundTracker.getAdjustmentResultById(id); + } + /** * Get the thread response of a thread for asking * @param threadId diff --git a/wren-ui/src/apollo/server/services/askingTaskTracker.ts b/wren-ui/src/apollo/server/services/askingTaskTracker.ts index 2ed0274abe..a2f18a61f0 100644 --- a/wren-ui/src/apollo/server/services/askingTaskTracker.ts +++ b/wren-ui/src/apollo/server/services/askingTaskTracker.ts @@ -404,7 +404,7 @@ export class AskingTaskTracker implements IAskingTaskTracker { } return { - ...taskRecord?.detail, + ...(taskRecord?.detail as AskResult), queryId: queryId || taskRecord?.queryId, question: taskRecord?.question, taskId: taskRecord?.id, diff --git a/wren-ui/src/apollo/server/telemetry/telemetry.ts b/wren-ui/src/apollo/server/telemetry/telemetry.ts index 99c469d54a..930982013b 100644 --- a/wren-ui/src/apollo/server/telemetry/telemetry.ts +++ b/wren-ui/src/apollo/server/telemetry/telemetry.ts @@ -56,6 +56,12 @@ export enum TelemetryEvent { HOME_GENERATE_PROJECT_RECOMMENDATION_QUESTIONS = 'home_generate_project_recommendation_questions', HOME_GENERATE_THREAD_RECOMMENDATION_QUESTIONS = 'home_generate_thread_recommendation_questions', + // adjustment + HOME_ADJUST_THREAD_RESPONSE = 'home_adjust_thread_response', + HOME_ADJUST_THREAD_RESPONSE_CANCEL = 'home_adjust_thread_response_cancel', + HOME_ADJUST_THREAD_RESPONSE_RERUN = 'home_adjust_thread_response_rerun', + HOME_ADJUST_THREAD_RESPONSE_WITH_SQL = 'home_adjust_thread_response_with_sql', + // event after ask HOME_CREATE_VIEW = 'home_create_view', HOME_PREVIEW_ANSWER = 'home_preview_answer', diff --git a/wren-ui/src/apollo/server/utils/error.ts b/wren-ui/src/apollo/server/utils/error.ts index 58fc196401..87b8069033 100644 --- a/wren-ui/src/apollo/server/utils/error.ts +++ b/wren-ui/src/apollo/server/utils/error.ts @@ -41,6 +41,9 @@ export enum GeneralErrorCodes { GENERATE_QUESTIONS_ERROR = 'GENERATE_QUESTIONS_ERROR', INVALID_SQL_ERROR = 'INVALID_SQL_ERROR', + // wren engine error + WREN_ENGINE_ERROR = 'WREN_ENGINE_ERROR', + // asking task error // when rerun from cancelled, the task is identified as general or misleading query IDENTIED_AS_GENERAL = 'IDENTIED_AS_GENERAL', diff --git a/wren-ui/src/common.ts b/wren-ui/src/common.ts index aee4518976..8337b2b6c2 100644 --- a/wren-ui/src/common.ts +++ b/wren-ui/src/common.ts @@ -128,6 +128,7 @@ export const initComponents = () => { queryService, mdlService, askingTaskTracker, + askingTaskRepository, }); const dashboardService = new DashboardService({ projectService, diff --git a/wren-ui/src/hooks/useAskPrompt.tsx b/wren-ui/src/hooks/useAskPrompt.tsx index 3f5c2b0f67..e9f242874b 100644 --- a/wren-ui/src/hooks/useAskPrompt.tsx +++ b/wren-ui/src/hooks/useAskPrompt.tsx @@ -228,12 +228,14 @@ export default function useAskPrompt(threadId?: number) { checkFetchAskingStreamTask(askingTask); } } + }, [askingTask?.status, threadId, checkFetchAskingStreamTask]); + useEffect(() => { // handle instant recommended questions if (isNeedRecommendedQuestions(askingTask)) { startRecommendedQuestions(); } - }, [askingTask, threadId, checkFetchAskingStreamTask]); + }, [askingTask?.type]); useEffect(() => { if (isRecommendedFinished(recommendedQuestions?.status))