From a74b9325367bcf3690251959f5f1c8500b3b3c1e Mon Sep 17 00:00:00 2001 From: Philip Langer Date: Thu, 14 Nov 2024 20:24:29 +0100 Subject: [PATCH 1/3] feat(ai): agents can ask for input and continue * Allow custom chat agents to stop completing the response conditionally * Introduce an orthogonal response state called `waitingForInput` * Introduce `show` setting on progress messages to control visibility * 'untilFirstContent': Disappears when first response content appears * 'whileIncomplete': Remains visible while response is incomplete * 'forever': Remains visible forever. * Adds a `QuestionResponseContent` and `QuestionPartRenderer` * Adds an API example agent 'AskAndContinue' that uses these features * Introduces agent-specific content matchers (in contrast to globals) * Removes redundant response completion & recording in `AbstractChatAgent` Contributed on behalf of STMicroelectronics. --- examples/api-samples/package.json | 2 + .../browser/api-samples-frontend-module.ts | 2 + ...sk-and-continue-chat-agent-contribution.ts | 117 ++++++++++++++++++ examples/api-samples/tsconfig.json | 6 + .../src/browser/ai-chat-ui-frontend-module.ts | 2 + .../question-part-renderer.tsx | 59 +++++++++ .../chat-tree-view/chat-view-tree-widget.tsx | 26 +++- .../ai-chat-ui/src/browser/style/index.css | 29 ++++- packages/ai-chat/src/common/chat-agents.ts | 50 +++++--- .../ai-chat/src/common/chat-model-util.ts | 44 +++++++ packages/ai-chat/src/common/chat-model.ts | 61 +++++++++ packages/ai-chat/src/common/index.ts | 1 + .../ai-chat/src/common/parse-contents.spec.ts | 30 ++--- packages/ai-chat/src/common/parse-contents.ts | 9 +- .../src/common/response-content-matcher.ts | 7 +- 15 files changed, 402 insertions(+), 43 deletions(-) create mode 100644 examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts create mode 100644 packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx create mode 100644 
packages/ai-chat/src/common/chat-model-util.ts diff --git a/examples/api-samples/package.json b/examples/api-samples/package.json index 7ff6acaadb5b0..999ac34cd4407 100644 --- a/examples/api-samples/package.json +++ b/examples/api-samples/package.json @@ -4,6 +4,8 @@ "version": "1.55.0", "description": "Theia - Example code to demonstrate Theia API", "dependencies": { + "@theia/ai-core": "1.55.0", + "@theia/ai-chat": "1.55.0", "@theia/ai-chat-ui": "1.55.0", "@theia/core": "1.55.0", "@theia/file-search": "1.55.0", diff --git a/examples/api-samples/src/browser/api-samples-frontend-module.ts b/examples/api-samples/src/browser/api-samples-frontend-module.ts index fc41efb1bce03..01ae0987bd7ec 100644 --- a/examples/api-samples/src/browser/api-samples-frontend-module.ts +++ b/examples/api-samples/src/browser/api-samples-frontend-module.ts @@ -31,6 +31,7 @@ import { bindSampleAppInfo } from './vsx/sample-frontend-app-info'; import { bindTestSample } from './test/sample-test-contribution'; import { bindSampleFileSystemCapabilitiesCommands } from './file-system/sample-file-system-capabilities'; import { bindChatNodeToolbarActionContribution } from './chat/chat-node-toolbar-action-contribution'; +import { bindAskAndContinueChatAgentContribution } from './chat/ask-and-continue-chat-agent-contribution'; export default new ContainerModule(( bind: interfaces.Bind, @@ -38,6 +39,7 @@ export default new ContainerModule(( isBound: interfaces.IsBound, rebind: interfaces.Rebind, ) => { + bindAskAndContinueChatAgentContribution(bind); bindChatNodeToolbarActionContribution(bind); bindDynamicLabelProvider(bind); bindSampleUnclosableView(bind); diff --git a/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts new file mode 100644 index 0000000000000..691e78f894db9 --- /dev/null +++ b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts @@ -0,0 +1,117 
@@ +// ***************************************************************************** +// Copyright (C) 2024 STMicroelectronics and others. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** + +import { + AbstractStreamParsingChatAgent, + ChatAgent, + ChatRequestModelImpl, + lastProgressMessage, + QuestionResponseContent, + SystemMessageDescription, + unansweredQuestions +} from '@theia/ai-chat'; +import { Agent, PromptTemplate } from '@theia/ai-core'; +import { injectable, interfaces } from '@theia/core/shared/inversify'; + +export function bindAskAndContinueChatAgentContribution(bind: interfaces.Bind): void { + bind(AskAndContinueChatAgent).toSelf().inSingletonScope(); + bind(Agent).toService(AskAndContinueChatAgent); + bind(ChatAgent).toService(AskAndContinueChatAgent); +} + +const systemPrompt: PromptTemplate = { + id: 'askAndContinue-system', + template: ` +Whatever the user inputs, you will write one arbitrary sentence and then ask a question with +the following format and two or three options: + + +{ + "question": "YOUR QUESTION HERE", + "options": [ + { + "text": "OPTION 1" + }, + { + "text": "OPTION 2" + } + ] +} + + ` +}; + +@injectable() +export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent implements ChatAgent { + override id = 'AskAndContinue'; + readonly name = 
'AskAndContinue'; + override defaultLanguageModelPurpose = 'chat'; + readonly description = 'What ever you input, this chat will ask a question and continues after that.'; + readonly variables = []; + readonly agentSpecificVariables = []; + readonly functions = []; + + override additionalContentMatchers = [ + { + start: /^.*$/m, + end: /^<\/question>$/m, + contentFactory: (content: string, request: ChatRequestModelImpl) => { + const question = content.replace(/^\n|<\/question>$/g, ''); + const parsedQuestion = JSON.parse(question); + return { + kind: 'question', + question: parsedQuestion.question, + options: parsedQuestion.options, + request, + handler: (option, _request) => this.handleAnswer(option, _request) + }; + } + } + ]; + + override languageModelRequirements = [ + { + purpose: 'chat', + identifier: 'openai/gpt-4o', + } + ]; + + readonly promptTemplates = [systemPrompt]; + + protected override async getSystemMessageDescription(): Promise { + const resolvedPrompt = await this.promptService.getPrompt(systemPrompt.id); + return resolvedPrompt ? 
SystemMessageDescription.fromResolvedPromptTemplate(resolvedPrompt) : undefined; + } + + protected override async onResponseComplete(request: ChatRequestModelImpl): Promise { + const unansweredQs = unansweredQuestions(request); + if (unansweredQs.length < 1) { + return super.onResponseComplete(request); + } + request.response.addProgressMessage({ content: 'Waiting for input...', show: 'whileIncomplete' }); + request.response.waitForInput(); + } + + protected handleAnswer(selectedOption: { text: string; value?: string; }, request: ChatRequestModelImpl): void { + const progressMessage = lastProgressMessage(request); + if (progressMessage) { + request.response.updateProgressMessage({ ...progressMessage, show: 'untilFirstContent', status: 'completed' }); + } + request.response.continue(); + this.invoke(request); + } +} + diff --git a/examples/api-samples/tsconfig.json b/examples/api-samples/tsconfig.json index 551c17de9f91b..000f5e8c524f3 100644 --- a/examples/api-samples/tsconfig.json +++ b/examples/api-samples/tsconfig.json @@ -12,9 +12,15 @@ { "path": "../../dev-packages/ovsx-client" }, + { + "path": "../../packages/ai-chat" + }, { "path": "../../packages/ai-chat-ui" }, + { + "path": "../../packages/ai-core" + }, { "path": "../../packages/core" }, diff --git a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts index 285f2cadd42fc..b3c7e70f045f2 100644 --- a/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts +++ b/packages/ai-chat-ui/src/browser/ai-chat-ui-frontend-module.ts @@ -36,6 +36,7 @@ import { ChatViewLanguageContribution } from './chat-view-language-contribution' import { ChatViewWidget } from './chat-view-widget'; import { ChatViewWidgetToolbarContribution } from './chat-view-widget-toolbar-contribution'; import { EditorPreviewManager } from '@theia/editor-preview/lib/browser/editor-preview-manager'; +import { QuestionPartRenderer } from 
'./chat-response-renderer/question-part-renderer'; export default new ContainerModule((bind, _unbind, _isBound, rebind) => { bindViewContribution(bind, AIChatContribution); @@ -66,6 +67,7 @@ export default new ContainerModule((bind, _unbind, _isBound, rebind) => { bind(ChatResponsePartRenderer).to(CommandPartRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(ToolCallPartRenderer).inSingletonScope(); bind(ChatResponsePartRenderer).to(ErrorPartRenderer).inSingletonScope(); + bind(ChatResponsePartRenderer).to(QuestionPartRenderer).inSingletonScope(); [CommandContribution, MenuContribution].forEach(serviceIdentifier => bind(serviceIdentifier).to(ChatViewMenuContribution).inSingletonScope() ); diff --git a/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx b/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx new file mode 100644 index 0000000000000..6af0a065392f1 --- /dev/null +++ b/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx @@ -0,0 +1,59 @@ +// ***************************************************************************** +// Copyright (C) 2024 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. +// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. 
+// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** +import { ChatResponseContent, QuestionResponseContent } from '@theia/ai-chat'; +import { injectable } from '@theia/core/shared/inversify'; +import * as React from '@theia/core/shared/react'; +import { ReactNode } from '@theia/core/shared/react'; +import { ChatResponsePartRenderer } from '../chat-response-part-renderer'; +import { ResponseNode } from '../chat-tree-view'; + +@injectable() +export class QuestionPartRenderer + implements ChatResponsePartRenderer { + + canHandle(response: ChatResponseContent): number { + if (QuestionResponseContent.is(response)) { + return 10; + } + return -1; + } + + render(question: QuestionResponseContent, node: ResponseNode): ReactNode { + return ( +
+
{question.question}
+
+ { + question.options.map((option, index) => ( + + )) + } +
+
+ ); + } + +} diff --git a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx index 668eefef7496d..175eec8547c07 100644 --- a/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx +++ b/packages/ai-chat-ui/src/browser/chat-tree-view/chat-view-tree-widget.tsx @@ -267,6 +267,7 @@ export class ChatViewTreeWidget extends TreeWidget { private renderAgent(node: RequestNode | ResponseNode): React.ReactNode { const inProgress = isResponseNode(node) && !node.response.isComplete && !node.response.isCanceled && !node.response.isError; + const waitingForInput = isResponseNode(node) && node.response.isWaitingForInput; const toolbarContributions = !inProgress ? this.chatNodeToolbarActionContributions.getContributions() .flatMap(c => c.getToolbarActions(node)) @@ -277,7 +278,8 @@ export class ChatViewTreeWidget extends TreeWidget {

{this.getAgentLabel(node)}

- {inProgress && Generating} + {inProgress && !waitingForInput && Generating} + {inProgress && waitingForInput && Waiting for input}
{!inProgress && toolbarContributions.length > 0 && @@ -340,12 +342,28 @@ export class ChatViewTreeWidget extends TreeWidget {
{!node.response.isComplete && node.response.response.content.length === 0 - && node.response.progressMessages.map((c, i) => - - )} + && node.response.progressMessages + .filter(c => c.show === 'untilFirstContent') + .map((c, i) => + + ) + } {node.response.response.content.map((c, i) =>
{this.getChatResponsePartRenderer(c, node)}
)} + {!node.response.isComplete + && node.response.progressMessages + .filter(c => c.show === 'whileIncomplete') + .map((c, i) => + + ) + } + {node.response.progressMessages + .filter(c => c.show === 'forever') + .map((c, i) => + + ) + }
); } diff --git a/packages/ai-chat-ui/src/browser/style/index.css b/packages/ai-chat-ui/src/browser/style/index.css index 4c86cfe272689..bd9dc56775816 100644 --- a/packages/ai-chat-ui/src/browser/style/index.css +++ b/packages/ai-chat-ui/src/browser/style/index.css @@ -231,7 +231,7 @@ div:last-child > .theia-ChatNode { display: flex; flex-direction: column; gap: 2px; - border: 1px solid var(--theia-input-border); + border: var(--theia-border-width) solid var(--theia-input-border); border-radius: 4px; } @@ -265,6 +265,33 @@ div:last-child > .theia-ChatNode { background-color: var(--theia-input-border); } +.theia-QuestionPartRenderer-root { + display: flex; + flex-direction: column; + gap: 8px; + border: var(--theia-border-width) solid + var(--theia-sideBarSectionHeader-border); + padding: 8px 12px 12px; + border-radius: 5px; + margin: 0 0 8px 0; +} +.theia-QuestionPartRenderer-options { + display: flex; + flex-wrap: wrap; + gap: 12px; +} +.theia-QuestionPartRenderer-option { + min-width: 100px; + flex: 1 1 auto; + margin: 0; +} +.theia-QuestionPartRenderer-option.selected:disabled:hover { + background-color: var(--theia-button-disabledBackground); +} +.theia-QuestionPartRenderer-option:disabled:not(.selected) { + background-color: var(--theia-button-secondaryBackground); +} + .theia-toolCall { font-weight: normal; color: var(--theia-descriptionForeground); diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index f5b9f1e735e99..9dcf9a029808f 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -128,6 +128,11 @@ export abstract class AbstractChatAgent { @inject(ContributionProvider) @named(ResponseContentMatcherProvider) protected contentMatcherProviders: ContributionProvider; protected contentMatchers: ResponseContentMatcher[] = []; + /** + * Agent-specific content matchers used by this agent in addition to the contributed content matchers. 
+ * @see ResponseContentMatcherProvider + */ + protected additionalContentMatchers: ResponseContentMatcher[] = []; @inject(DefaultResponseContentFactory) protected defaultContentFactory: DefaultResponseContentFactory; @@ -144,7 +149,15 @@ export abstract class AbstractChatAgent { @postConstruct() init(): void { - this.contentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers); + this.initializeContentMatchers(); + } + + protected initializeContentMatchers(): void { + const contributedContentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers); + this.contentMatchers = [ + ...contributedContentMatchers, + ...this.additionalContentMatchers + ]; } async invoke(request: ChatRequestModelImpl): Promise { @@ -195,7 +208,7 @@ export abstract class AbstractChatAgent { cancellationToken.token ); await this.addContentsToResponse(languageModelResponse, request); - request.response.complete(); + await this.onResponseComplete(request); if (this.defaultLogging) { this.recordingService.recordResponse(ChatHistoryEntry.fromResponse(this.id, request)); } @@ -204,9 +217,10 @@ export abstract class AbstractChatAgent { } } - protected parseContents(text: string): ChatResponseContent[] { + protected parseContents(text: string, request: ChatRequestModelImpl): ChatResponseContent[] { return parseContents( text, + request, this.contentMatchers, this.defaultContentFactory?.create.bind(this.defaultContentFactory) ); @@ -290,6 +304,16 @@ export abstract class AbstractChatAgent { return undefined; } + /** + * Invoked after the response by the LLM completed successfully. + * + * The default implementation sets the state of the response to `complete`. + * Subclasses may override this method to perform additional actions or keep the response open for processing further requests. 
+ */ + protected async onResponseComplete(request: ChatRequestModelImpl): Promise { + return request.response.complete(); + } + protected abstract addContentsToResponse(languageModelResponse: LanguageModelResponse, request: ChatRequestModelImpl): Promise; } @@ -313,20 +337,12 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { protected override async addContentsToResponse(languageModelResponse: LanguageModelResponse, request: ChatRequestModelImpl): Promise { if (isLanguageModelTextResponse(languageModelResponse)) { - const contents = this.parseContents(languageModelResponse.text); + const contents = this.parseContents(languageModelResponse.text, request); request.response.response.addContents(contents); - request.response.complete(); - if (this.defaultLogging) { - this.recordingService.recordResponse(ChatHistoryEntry.fromResponse(this.id, request)); - } return; } if (isLanguageModelStreamResponse(languageModelResponse)) { await this.addStreamResponse(languageModelResponse, request); - request.response.complete(); - if (this.defaultLogging) { - this.recordingService.recordResponse(ChatHistoryEntry.fromResponse(this.id, request)); - } return; } this.logger.error( @@ -341,7 +357,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { protected async addStreamResponse(languageModelResponse: LanguageModelStreamResponse, request: ChatRequestModelImpl): Promise { for await (const token of languageModelResponse.stream) { - const newContents = this.parse(token, request.response.response.content); + const newContents = this.parse(token, request); if (isArray(newContents)) { request.response.response.addContents(newContents); } else { @@ -357,7 +373,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { return; } - const result: ChatResponseContent[] = findFirstMatch(this.contentMatchers, text) ? 
this.parseContents(text) : []; + const result: ChatResponseContent[] = findFirstMatch(this.contentMatchers, text) ? this.parseContents(text, request) : []; if (result.length > 0) { request.response.response.addContents(result); } else { @@ -366,11 +382,11 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { } } - protected parse(token: LanguageModelStreamResponsePart, previousContent: ChatResponseContent[]): ChatResponseContent | ChatResponseContent[] { + protected parse(token: LanguageModelStreamResponsePart, request: ChatRequestModelImpl): ChatResponseContent | ChatResponseContent[] { const content = token.content; // eslint-disable-next-line no-null/no-null if (content !== undefined && content !== null) { - return this.defaultContentFactory.create(content); + return this.defaultContentFactory.create(content, request); } const toolCalls = token.tool_calls; if (toolCalls !== undefined) { @@ -378,7 +394,7 @@ export abstract class AbstractStreamParsingChatAgent extends AbstractChatAgent { new ToolCallChatResponseContentImpl(toolCall.id, toolCall.function?.name, toolCall.function?.arguments, toolCall.finished, toolCall.result)); return toolCallContents; } - return this.defaultContentFactory.create(''); + return this.defaultContentFactory.create('', request); } } diff --git a/packages/ai-chat/src/common/chat-model-util.ts b/packages/ai-chat/src/common/chat-model-util.ts new file mode 100644 index 0000000000000..1bad8b6bad0c6 --- /dev/null +++ b/packages/ai-chat/src/common/chat-model-util.ts @@ -0,0 +1,44 @@ +// ***************************************************************************** +// Copyright (C) 2024 EclipseSource GmbH. +// +// This program and the accompanying materials are made available under the +// terms of the Eclipse Public License v. 2.0 which is available at +// http://www.eclipse.org/legal/epl-2.0. 
+// +// This Source Code may also be made available under the following Secondary +// Licenses when the conditions for such availability set forth in the Eclipse +// Public License v. 2.0 are satisfied: GNU General Public License, version 2 +// with the GNU Classpath Exception which is available at +// https://www.gnu.org/software/classpath/license.html. +// +// SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 +// ***************************************************************************** +import { ChatProgressMessage, ChatRequestModel, ChatResponse, ChatResponseContent, ChatResponseModel, QuestionResponseContent } from './chat-model'; + +export function lastResponseContent(request: ChatRequestModel): ChatResponseContent | undefined { + return lastContentOfResponse(request.response?.response); +} + +export function lastContentOfResponse(response: ChatResponse | undefined): ChatResponseContent | undefined { + const content = response?.content; + return content && content.length > 0 ? content[content.length - 1] : undefined; +} + +export function lastProgressMessage(request: ChatRequestModel): ChatProgressMessage | undefined { + return lastProgressMessageOfResponse(request.response); +} + +export function lastProgressMessageOfResponse(response: ChatResponseModel | undefined): ChatProgressMessage | undefined { + const progressMessages = response?.progressMessages; + return progressMessages && progressMessages.length > 0 ? 
progressMessages[progressMessages.length - 1] : undefined; +} + +export function unansweredQuestions(request: ChatRequestModel): QuestionResponseContent[] { + const response = request.response; + return unansweredQuestionsOfResponse(response); +} + +function unansweredQuestionsOfResponse(response: ChatResponseModel | undefined): QuestionResponseContent[] { + if (!response || !response.response) { return []; } + return response.response.content.filter((c): c is QuestionResponseContent => QuestionResponseContent.is(c) && c.selectedOption === undefined); +} diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 0decfb284da35..404c3d11a41b9 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -80,6 +80,7 @@ export interface ChatProgressMessage { kind: 'progressMessage'; id: string; status: 'inProgress' | 'completed' | 'failed'; + show: 'untilFirstContent' | 'whileIncomplete' | 'forever'; content: string; } @@ -279,6 +280,44 @@ export namespace ErrorChatResponseContent { } } +export type QuestionResponseHandler = ( + selectedOption: { text: string, value?: string }, + request: ChatRequestModelImpl +) => void; + +export interface QuestionResponseContent extends ChatResponseContent { + kind: 'question'; + question: string; + options: { text: string, value?: string }[]; + selectedOption?: { text: string, value?: string }; + handler: QuestionResponseHandler; + request: ChatRequestModelImpl; +} + +export namespace QuestionResponseContent { + export function is(obj: unknown): obj is QuestionResponseContent { + return ( + ChatResponseContent.is(obj) && + obj.kind === 'question' && + 'question' in obj && + typeof (obj as { question: unknown }).question === 'string' && + 'options' in obj && + Array.isArray((obj as { options: unknown }).options) && + (obj as { options: unknown[] }).options.every(option => + typeof option === 'object' && + // eslint-disable-next-line 
no-null/no-null + option !== null && 'text' in option && + typeof (option as { text: unknown }).text === 'string' && + ('value' in option ? typeof (option as { value: unknown }).value === 'string' || typeof (option as { value: unknown }).value === 'undefined' : true) + ) && + 'handler' in obj && + typeof (obj as { handler: unknown }).handler === 'function' && + 'request' in obj && + obj.request instanceof ChatRequestModelImpl + ); + } +} + export interface ChatResponse { readonly content: ChatResponseContent[]; asString(): string; @@ -292,6 +331,7 @@ export interface ChatResponseModel { readonly response: ChatResponse; readonly isComplete: boolean; readonly isCanceled: boolean; + readonly isWaitingForInput: boolean; readonly isError: boolean; readonly agentId?: string readonly errorObject?: Error; @@ -688,6 +728,7 @@ class ChatResponseModelImpl implements ChatResponseModel { protected _response: ChatResponseImpl; protected _isComplete: boolean; protected _isCanceled: boolean; + protected _isWaitingForInput: boolean; protected _agentId?: string; protected _isError: boolean; protected _errorObject: Error | undefined; @@ -702,6 +743,7 @@ class ChatResponseModelImpl implements ChatResponseModel { this._response = response; this._isComplete = false; this._isCanceled = false; + this._isWaitingForInput = false; this._agentId = agentId; } @@ -728,6 +770,7 @@ class ChatResponseModelImpl implements ChatResponseModel { kind: 'progressMessage', id, status: message.status ?? 'inProgress', + show: message.show ?? 
'untilFirstContent', ...message, }; this._progressMessages.push(newMessage); @@ -759,6 +802,10 @@ class ChatResponseModelImpl implements ChatResponseModel { return this._isCanceled; } + get isWaitingForInput(): boolean { + return this._isWaitingForInput; + } + get agentId(): string | undefined { return this._agentId; } @@ -769,17 +816,31 @@ class ChatResponseModelImpl implements ChatResponseModel { complete(): void { this._isComplete = true; + this._isWaitingForInput = false; this._onDidChangeEmitter.fire(); } cancel(): void { this._isComplete = true; this._isCanceled = true; + this._isWaitingForInput = false; + this._onDidChangeEmitter.fire(); + } + + waitForInput(): void { + this._isWaitingForInput = true; this._onDidChangeEmitter.fire(); } + + continue(): void { + this._isWaitingForInput = false; + this._onDidChangeEmitter.fire(); + } + error(error: Error): void { this._isComplete = true; this._isCanceled = false; + this._isWaitingForInput = false; this._isError = true; this._errorObject = error; this._onDidChangeEmitter.fire(); diff --git a/packages/ai-chat/src/common/index.ts b/packages/ai-chat/src/common/index.ts index cf160ddcadf10..b0100cff31203 100644 --- a/packages/ai-chat/src/common/index.ts +++ b/packages/ai-chat/src/common/index.ts @@ -16,6 +16,7 @@ export * from './chat-agents'; export * from './chat-agent-service'; export * from './chat-model'; +export * from './chat-model-util'; export * from './chat-request-parser'; export * from './chat-service'; export * from './command-chat-agents'; diff --git a/packages/ai-chat/src/common/parse-contents.spec.ts b/packages/ai-chat/src/common/parse-contents.spec.ts index c0a009f8cb814..cba9fa1b598e6 100644 --- a/packages/ai-chat/src/common/parse-contents.spec.ts +++ b/packages/ai-chat/src/common/parse-contents.spec.ts @@ -15,7 +15,7 @@ // ***************************************************************************** import { expect } from 'chai'; -import { ChatResponseContent, CodeChatResponseContentImpl, 
MarkdownChatResponseContentImpl } from './chat-model'; +import { ChatRequestModelImpl, ChatResponseContent, CodeChatResponseContentImpl, MarkdownChatResponseContentImpl } from './chat-model'; import { parseContents } from './parse-contents'; import { CodeContentMatcher, ResponseContentMatcher } from './response-content-matcher'; @@ -33,22 +33,24 @@ export const CommandContentMatcher: ResponseContentMatcher = { } }; +const fakeRequest = {} as ChatRequestModelImpl; + describe('parseContents', () => { it('should parse code content', () => { const text = '```typescript\nconsole.log("Hello World");\n```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript')]); }); it('should parse markdown content', () => { const text = 'Hello **World**'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('Hello **World**')]); }); it('should parse multiple content blocks', () => { const text = '```typescript\nconsole.log("Hello World");\n```\nHello **World**'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'), new MarkdownChatResponseContentImpl('\nHello **World**') @@ -57,7 +59,7 @@ describe('parseContents', () => { it('should parse multiple content blocks with different languages', () => { const text = '```typescript\nconsole.log("Hello World");\n```\n```python\nprint("Hello World")\n```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'), new CodeChatResponseContentImpl('print("Hello World")', 'python') @@ -66,7 +68,7 @@ describe('parseContents', () => 
{ it('should parse multiple content blocks with different languages and markdown', () => { const text = '```typescript\nconsole.log("Hello World");\n```\nHello **World**\n```python\nprint("Hello World")\n```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'), new MarkdownChatResponseContentImpl('\nHello **World**\n'), @@ -76,7 +78,7 @@ describe('parseContents', () => { it('should parse content blocks with empty content', () => { const text = '```typescript\n```\nHello **World**\n```python\nprint("Hello World")\n```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new CodeChatResponseContentImpl('', 'typescript'), new MarkdownChatResponseContentImpl('\nHello **World**\n'), @@ -86,7 +88,7 @@ describe('parseContents', () => { it('should parse content with markdown, code, and markdown', () => { const text = 'Hello **World**\n```typescript\nconsole.log("Hello World");\n```\nGoodbye **World**'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new MarkdownChatResponseContentImpl('Hello **World**\n'), new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'), @@ -96,25 +98,25 @@ describe('parseContents', () => { it('should handle text with no special content', () => { const text = 'Just some plain text.'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('Just some plain text.')]); }); it('should handle text with only start code block', () => { const text = '```typescript\nconsole.log("Hello World");'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([new 
MarkdownChatResponseContentImpl('```typescript\nconsole.log("Hello World");')]); }); it('should handle text with only end code block', () => { const text = 'console.log("Hello World");\n```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([new MarkdownChatResponseContentImpl('console.log("Hello World");\n```')]); }); it('should handle text with unmatched code block', () => { const text = '```typescript\nconsole.log("Hello World");\n```\n```python\nprint("Hello World")'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new CodeChatResponseContentImpl('console.log("Hello World");', 'typescript'), new MarkdownChatResponseContentImpl('\n```python\nprint("Hello World")') @@ -123,7 +125,7 @@ describe('parseContents', () => { it('should parse code block without newline after language', () => { const text = '```typescript console.log("Hello World");```'; - const result = parseContents(text); + const result = parseContents(text, fakeRequest); expect(result).to.deep.equal([ new MarkdownChatResponseContentImpl('```typescript console.log("Hello World");```') ]); @@ -131,7 +133,7 @@ describe('parseContents', () => { it('should parse with matches of multiple different matchers and default', () => { const text = '\nMY_SPECIAL_COMMAND\n\nHello **World**\n```python\nprint("Hello World")\n```\n\nMY_SPECIAL_COMMAND2\n'; - const result = parseContents(text, [CodeContentMatcher, CommandContentMatcher]); + const result = parseContents(text, fakeRequest, [CodeContentMatcher, CommandContentMatcher]); expect(result).to.deep.equal([ new CommandChatResponseContentImpl('MY_SPECIAL_COMMAND'), new MarkdownChatResponseContentImpl('\nHello **World**\n'), diff --git a/packages/ai-chat/src/common/parse-contents.ts b/packages/ai-chat/src/common/parse-contents.ts index 16f405495ce20..1dd1afbbe1ee8 100644 --- a/packages/ai-chat/src/common/parse-contents.ts +++ 
b/packages/ai-chat/src/common/parse-contents.ts @@ -13,7 +13,7 @@ * * SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 */ -import { ChatResponseContent } from './chat-model'; +import { ChatRequestModelImpl, ChatResponseContent } from './chat-model'; import { CodeContentMatcher, MarkdownContentFactory, ResponseContentFactory, ResponseContentMatcher } from './response-content-matcher'; interface Match { @@ -24,6 +24,7 @@ interface Match { export function parseContents( text: string, + request: ChatRequestModelImpl, contentMatchers: ResponseContentMatcher[] = [CodeContentMatcher], defaultContentFactory: ResponseContentFactory = MarkdownContentFactory ): ChatResponseContent[] { @@ -36,7 +37,7 @@ export function parseContents( if (!match) { // Add the remaining text as default content if (remainingText.length > 0) { - result.push(defaultContentFactory(remainingText)); + result.push(defaultContentFactory(remainingText, request)); } break; } @@ -45,11 +46,11 @@ export function parseContents( if (match.index > 0) { const precedingContent = remainingText.substring(0, match.index); if (precedingContent.trim().length > 0) { - result.push(defaultContentFactory(precedingContent)); + result.push(defaultContentFactory(precedingContent, request)); } } // 2. 
Add the matched content object - result.push(match.matcher.contentFactory(match.content)); + result.push(match.matcher.contentFactory(match.content, request)); // Update currentIndex to the end of the end of the match // And continue with the search after the end of the match currentIndex += match.index + match.content.length; diff --git a/packages/ai-chat/src/common/response-content-matcher.ts b/packages/ai-chat/src/common/response-content-matcher.ts index 3fb785e603c5f..86aa7e83316cb 100644 --- a/packages/ai-chat/src/common/response-content-matcher.ts +++ b/packages/ai-chat/src/common/response-content-matcher.ts @@ -14,13 +14,14 @@ * SPDX-License-Identifier: EPL-2.0 OR GPL-2.0-only WITH Classpath-exception-2.0 */ import { + ChatRequestModelImpl, ChatResponseContent, CodeChatResponseContentImpl, MarkdownChatResponseContentImpl } from './chat-model'; import { injectable } from '@theia/core/shared/inversify'; -export type ResponseContentFactory = (content: string) => ChatResponseContent; +export type ResponseContentFactory = (content: string, request: ChatRequestModelImpl) => ChatResponseContent; export const MarkdownContentFactory: ResponseContentFactory = (content: string) => new MarkdownChatResponseContentImpl(content); @@ -33,8 +34,8 @@ export const MarkdownContentFactory: ResponseContentFactory = (content: string) */ @injectable() export class DefaultResponseContentFactory { - create(content: string): ChatResponseContent { - return MarkdownContentFactory(content); + create(content: string, request: ChatRequestModelImpl): ChatResponseContent { + return MarkdownContentFactory(content, request); } } From 373ab178722b616da79d7aab017056aa81bb24ce Mon Sep 17 00:00:00 2001 From: Stefan Dirix Date: Tue, 26 Nov 2024 11:59:30 +0100 Subject: [PATCH 2/3] review adaptations --- ...sk-and-continue-chat-agent-contribution.ts | 105 +++++++++++++++--- .../question-part-renderer.tsx | 2 +- packages/ai-chat/src/common/chat-agents.ts | 10 +- 
packages/ai-chat/src/common/chat-model.ts | 34 +++++- 4 files changed, 121 insertions(+), 30 deletions(-) diff --git a/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts index 691e78f894db9..77da6a6704c13 100644 --- a/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts +++ b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts @@ -17,14 +17,16 @@ import { AbstractStreamParsingChatAgent, ChatAgent, + ChatMessage, + ChatModel, ChatRequestModelImpl, lastProgressMessage, - QuestionResponseContent, + QuestionResponseContentImpl, SystemMessageDescription, unansweredQuestions } from '@theia/ai-chat'; import { Agent, PromptTemplate } from '@theia/ai-core'; -import { injectable, interfaces } from '@theia/core/shared/inversify'; +import { injectable, interfaces, postConstruct } from '@theia/core/shared/inversify'; export function bindAskAndContinueChatAgentContribution(bind: interfaces.Bind): void { bind(AskAndContinueChatAgent).toSelf().inSingletonScope(); @@ -35,8 +37,13 @@ export function bindAskAndContinueChatAgentContribution(bind: interfaces.Bind): const systemPrompt: PromptTemplate = { id: 'askAndContinue-system', template: ` -Whatever the user inputs, you will write one arbitrary sentence and then ask a question with -the following format and two or three options: +You are an agent demonstrating on how to generate questions and continuing the conversation based on the user's answers. + +First answer the user's question or continue their story. +Then come up with an interesting question and 2-3 answers which will be presented to the user as multiple choice. + +Use the following format exactly to define the questions and answers. +Especially add the <question> and </question> tags around the JSON.
<question> { @@ -51,36 +58,80 @@ the following format and two or three options: ] } </question> - ` + +Examples: + +<question> +{ + "question": "What is the capital of France?", + "options": [ + { + "text": "Paris" + }, + { + "text": "Lyon" + } + ] +} +</question> + +<question> +{ + "question": "What does the fox say?", + "options": [ + { + "text": "Ring-ding-ding-ding-dingeringeding!" + }, + { + "text": "Wa-pa-pa-pa-pa-pa-pow!" + } + ] +} +</question> + +The user will answer the question and you can continue the conversation. +Once they answered, the question will be replaced with a simple "Question/Answer" pair, for example + +Question: What does the fox say? +Answer: Ring-ding-ding-ding-dingeringeding! + +If the user did not answer the question, it will be marked with "No answer", for example + +Question: What is the capital of France? +No answer + +Do not generate such pairs yourself, instead treat them as a signal for a past question. +Do not ask further questions once the text contains 5 or more "Question/Answer" pairs. +` }; +/** + * This is a very simple example agent that asks questions and continues the conversation based on the user's answers.
+ */ +@injectable() export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent implements ChatAgent { override id = 'AskAndContinue'; readonly name = 'AskAndContinue'; override defaultLanguageModelPurpose = 'chat'; - readonly description = 'What ever you input, this chat will ask a question and continues after that.'; + readonly description = 'This chat will ask questions related to the input and continues after that.'; readonly variables = []; readonly agentSpecificVariables = []; readonly functions = []; - override additionalContentMatchers = [ - { + @postConstruct() + addContentMatchers(): void { + this.contentMatchers.push({ start: /^<question>.*$/m, end: /^<\/question>$/m, contentFactory: (content: string, request: ChatRequestModelImpl) => { const question = content.replace(/^<question>\n|<\/question>$/g, ''); const parsedQuestion = JSON.parse(question); - return { - kind: 'question', - question: parsedQuestion.question, - options: parsedQuestion.options, - request, - handler: (option, _request) => this.handleAnswer(option, _request) - }; + return new QuestionResponseContentImpl(parsedQuestion.question, parsedQuestion.options, request, selectedOption => { + this.handleAnswer(selectedOption, request); + }); } - } - ]; + }); + } override languageModelRequirements = [ { @@ -111,7 +162,27 @@ export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent impl request.response.updateProgressMessage({ ...progressMessage, show: 'untilFirstContent', status: 'completed' }); } request.response.continue(); + // We're reusing the original request here as a shortcut. In combination with the override of 'getMessages' we continue generating. + // In a real-world scenario, you would likely create a new request here. this.invoke(request); } + + /** + * As the question/answer are handled within the same response, we add an additional user message at the end to indicate to + * the LLM to continue generating.
+ */ + protected override async getMessages(model: ChatModel): Promise<ChatMessage[]> { + const messages = await super.getMessages(model, true); + const requests = model.getRequests(); + if (!requests[requests.length - 1].response.isComplete && requests[requests.length - 1].response.response?.content.length > 0) { + return [...messages, + { + type: 'text', + actor: 'user', + query: 'Continue generating based on the user\'s answer or finish the conversation if 5 or more questions were already answered.' + }]; + } + return messages; + } } diff --git a/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx b/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx index 6af0a065392f1..58d65d3ebb725 100644 --- a/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx +++ b/packages/ai-chat-ui/src/browser/chat-response-renderer/question-part-renderer.tsx @@ -42,7 +42,7 @@ export class QuestionPartRenderer className={`theia-button theia-QuestionPartRenderer-option ${question.selectedOption === option ? 'selected' : ''}`} onClick={() => { question.selectedOption = option; - question.handler(option, question.request); + question.handler(option); }} disabled={question.selectedOption !== undefined || !node.response.isWaitingForInput} key={index} diff --git a/packages/ai-chat/src/common/chat-agents.ts b/packages/ai-chat/src/common/chat-agents.ts index 9dcf9a029808f..542cbfc462c4b 100644 --- a/packages/ai-chat/src/common/chat-agents.ts +++ b/packages/ai-chat/src/common/chat-agents.ts @@ -128,11 +128,6 @@ export abstract class AbstractChatAgent { @inject(ContributionProvider) @named(ResponseContentMatcherProvider) protected contentMatcherProviders: ContributionProvider<ResponseContentMatcherProvider>; protected contentMatchers: ResponseContentMatcher[] = []; - /** - * Agent-specific content matchers used by this agent in addition to the contributed content matchers.
- * @see ResponseContentMatcherProvider - */ - protected additionalContentMatchers: ResponseContentMatcher[] = []; @inject(DefaultResponseContentFactory) protected defaultContentFactory: DefaultResponseContentFactory; @@ -154,10 +149,7 @@ export abstract class AbstractChatAgent { protected initializeContentMatchers(): void { const contributedContentMatchers = this.contentMatcherProviders.getContributions().flatMap(provider => provider.matchers); - this.contentMatchers = [ - ...contributedContentMatchers, - ...this.additionalContentMatchers - ]; + this.contentMatchers.push(...contributedContentMatchers); } async invoke(request: ChatRequestModelImpl): Promise { diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 404c3d11a41b9..69341ccf34590 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -282,7 +282,6 @@ export namespace ErrorChatResponseContent { export type QuestionResponseHandler = ( selectedOption: { text: string, value?: string }, - request: ChatRequestModelImpl ) => void; export interface QuestionResponseContent extends ChatResponseContent { @@ -305,8 +304,7 @@ export namespace QuestionResponseContent { Array.isArray((obj as { options: unknown }).options) && (obj as { options: unknown[] }).options.every(option => typeof option === 'object' && - // eslint-disable-next-line no-null/no-null - option !== null && 'text' in option && + option && 'text' in option && typeof (option as { text: unknown }).text === 'string' && ('value' in option ? typeof (option as { value: unknown }).value === 'string' || typeof (option as { value: unknown }).value === 'undefined' : true) ) && @@ -642,6 +640,31 @@ export class HorizontalLayoutChatResponseContentImpl implements HorizontalLayout } } +/** + * Default implementation for the QuestionResponseContent. 
+ */ +export class QuestionResponseContentImpl implements QuestionResponseContent { + readonly kind = 'question'; + protected _selectedOption: { text: string; value?: string } | undefined; + constructor(public question: string, public options: { text: string, value?: string }[], + public request: ChatRequestModelImpl, public handler: QuestionResponseHandler) { + } + set selectedOption(option: { text: string; value?: string; } | undefined) { + this._selectedOption = option; + this.request.response.response.responseContentChanged(); + } + get selectedOption(): { text: string; value?: string; } | undefined { + return this._selectedOption; + } + asString?(): string | undefined { + return `Question: ${this.question} +${this.selectedOption ? `Answer: ${this.selectedOption?.text}` : 'No answer'}`; + } + merge?(): boolean { + return false; + } +} + class ChatResponseImpl implements ChatResponse { protected readonly _onDidChangeEmitter = new Emitter(); onDidChange: Event = this._onDidChangeEmitter.event; @@ -694,6 +717,11 @@ class ChatResponseImpl implements ChatResponse { this._updateResponseRepresentation(); } + responseContentChanged(): void { + this._updateResponseRepresentation(); + this._onDidChangeEmitter.fire(); + } + protected _updateResponseRepresentation(): void { this._responseRepresentation = this._content .map(responseContent => { From cd54fece6358157c042cc5bca31e2a30b1d9fd64 Mon Sep 17 00:00:00 2001 From: Stefan Dirix Date: Tue, 26 Nov 2024 12:08:32 +0100 Subject: [PATCH 3/3] further review changes --- .../browser/chat/ask-and-continue-chat-agent-contribution.ts | 4 ++-- packages/ai-chat/src/common/chat-model.ts | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts index 77da6a6704c13..c571d22facaf2 100644 --- 
a/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts +++ b/examples/api-samples/src/browser/chat/ask-and-continue-chat-agent-contribution.ts @@ -161,9 +161,9 @@ export class AskAndContinueChatAgent extends AbstractStreamParsingChatAgent impl if (progressMessage) { request.response.updateProgressMessage({ ...progressMessage, show: 'untilFirstContent', status: 'completed' }); } - request.response.continue(); + request.response.stopWaitingForInput(); // We're reusing the original request here as a shortcut. In combination with the override of 'getMessages' we continue generating. - // In a real-world scenario, you would likely create a new request here. + // In a real-world scenario, you would likely manually interact with an LLM here to generate and append the next response. this.invoke(request); } diff --git a/packages/ai-chat/src/common/chat-model.ts b/packages/ai-chat/src/common/chat-model.ts index 69341ccf34590..9c6cd66af14b4 100644 --- a/packages/ai-chat/src/common/chat-model.ts +++ b/packages/ai-chat/src/common/chat-model.ts @@ -860,7 +860,7 @@ class ChatResponseModelImpl implements ChatResponseModel { this._onDidChangeEmitter.fire(); } - continue(): void { + stopWaitingForInput(): void { this._isWaitingForInput = false; this._onDidChangeEmitter.fire(); }