-
Notifications
You must be signed in to change notification settings - Fork 8.6k
[onechat] add tool progression events #233724
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
957db18
102f6af
8111e18
85a17d5
ef34480
5089269
fa9770e
2f5bcf8
2f88fe7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,6 +11,7 @@ import type { ConversationRound } from './conversation'; | |
|
|
||
| export enum ChatEventType { | ||
| toolCall = 'tool_call', | ||
| toolProgress = 'tool_progress', | ||
| toolResult = 'tool_result', | ||
| reasoning = 'reasoning', | ||
| messageChunk = 'message_chunk', | ||
|
|
@@ -39,6 +40,21 @@ export const isToolCallEvent = (event: OnechatEvent<string, any>): event is Tool | |
| return event.type === ChatEventType.toolCall; | ||
| }; | ||
|
|
||
| // Tool progress | ||
|
|
||
| export interface ToolProgressEventData { | ||
| tool_call_id: string; | ||
| message: string; | ||
| } | ||
|
Comment on lines
+45
to
+48
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The new We can make that evolve into something more structured later, but for now, simpler is better, and full text is gonna be fine ihmo. |
||
|
|
||
| export type ToolProgressEvent = ChatEventBase<ChatEventType.toolProgress, ToolProgressEventData>; | ||
|
|
||
| export const isToolProgressEvent = ( | ||
| event: OnechatEvent<string, any> | ||
| ): event is ToolProgressEvent => { | ||
| return event.type === ChatEventType.toolProgress; | ||
| }; | ||
|
|
||
| // Tool result | ||
|
|
||
| export interface ToolResultEventData { | ||
|
|
@@ -158,6 +174,7 @@ export const isConversationUpdatedEvent = ( | |
| */ | ||
| export type ChatAgentEvent = | ||
| | ToolCallEvent | ||
| | ToolProgressEvent | ||
| | ToolResultEvent | ||
| | ReasoningEvent | ||
| | MessageChunkEvent | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -46,7 +46,6 @@ describe('extractToolReturn', () => { | |
| }, | ||
| }, | ||
| ], | ||
| runId: 'unknown', | ||
| }); | ||
| }); | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,7 +9,16 @@ import type { StructuredTool } from '@langchain/core/tools'; | |
| import { tool as toTool } from '@langchain/core/tools'; | ||
| import type { Logger } from '@kbn/logging'; | ||
| import type { KibanaRequest } from '@kbn/core-http-server'; | ||
| import type { ToolProvider, ExecutableTool, RunToolReturn } from '@kbn/onechat-server'; | ||
| import type { ChatAgentEvent } from '@kbn/onechat-common'; | ||
| import { ChatEventType } from '@kbn/onechat-common'; | ||
| import type { | ||
| AgentEventEmitterFn, | ||
| ExecutableTool, | ||
| OnechatToolEvent, | ||
| RunToolReturn, | ||
| ToolProvider, | ||
| ToolEventHandlerFn, | ||
| } from '@kbn/onechat-server'; | ||
| import { ToolResultType } from '@kbn/onechat-common/tools/tool_result'; | ||
| import type { ToolCall } from './messages'; | ||
|
|
||
|
|
@@ -30,18 +39,20 @@ export const toolsToLangchain = async ({ | |
| request, | ||
| tools, | ||
| logger, | ||
| sendEvent, | ||
| }: { | ||
| request: KibanaRequest; | ||
| tools: ToolProvider | ExecutableTool[]; | ||
| logger: Logger; | ||
| sendEvent?: AgentEventEmitterFn; | ||
| }): Promise<ToolsAndMappings> => { | ||
| const allTools = Array.isArray(tools) ? tools : await tools.list({ request }); | ||
| const onechatToLangchainIdMap = createToolIdMappings(allTools); | ||
|
|
||
| const convertedTools = await Promise.all( | ||
| allTools.map((tool) => { | ||
| const toolId = onechatToLangchainIdMap.get(tool.id); | ||
| return toolToLangchain({ tool, logger, toolId }); | ||
| return toolToLangchain({ tool, logger, toolId, sendEvent }); | ||
| }) | ||
| ); | ||
|
|
||
|
|
@@ -83,25 +94,35 @@ export const toolToLangchain = ({ | |
| tool, | ||
| toolId, | ||
| logger, | ||
| sendEvent, | ||
| }: { | ||
| tool: ExecutableTool; | ||
| toolId?: string; | ||
| logger: Logger; | ||
| sendEvent?: AgentEventEmitterFn; | ||
| }): StructuredTool => { | ||
| return toTool( | ||
| async (input): Promise<[string, RunToolReturn]> => { | ||
| async (input, config): Promise<[string, RunToolReturn]> => { | ||
| let onEvent: ToolEventHandlerFn | undefined; | ||
| if (sendEvent) { | ||
| const toolCallId = config.configurable?.tool_call_id ?? config.toolCall?.id ?? 'unknown'; | ||
|
Comment on lines
+107
to
+108
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This one took time to find, but we do have the info of the tool_call_id within langchain wrappers, so we can "automatically" attach it to the underlying tool events. |
||
| const convertEvent = getToolEventConverter({ toolCallId }); | ||
| onEvent = (event) => { | ||
| sendEvent(convertEvent(event)); | ||
| }; | ||
| } | ||
|
|
||
| try { | ||
| logger.debug(`Calling tool ${tool.id} with params: ${JSON.stringify(input, null, 2)}`); | ||
| const toolReturn = await tool.execute({ toolParams: input }); | ||
| const content = JSON.stringify({ results: toolReturn.results }); // wrap in a results object to conform to bedrock format | ||
| const toolReturn = await tool.execute({ toolParams: input, onEvent }); | ||
| const content = JSON.stringify({ results: toolReturn.results }); | ||
| logger.debug(`Tool ${tool.id} returned reply of length ${content.length}`); | ||
| return [content, toolReturn]; | ||
| } catch (e) { | ||
| logger.warn(`error calling tool ${tool.id}: ${e}`); | ||
| logger.debug(e.stack); | ||
|
|
||
| const errorToolReturn: RunToolReturn = { | ||
| runId: tool.id, | ||
| results: [ | ||
| { | ||
| type: ToolResultType.error, | ||
|
|
@@ -141,3 +162,18 @@ function reverseMap<K, V>(map: Map<K, V>): Map<V, K> { | |
| } | ||
| return reversed; | ||
| } | ||
|
|
||
| const getToolEventConverter = ({ toolCallId }: { toolCallId: string }) => { | ||
| return (toolEvent: OnechatToolEvent): ChatAgentEvent => { | ||
| if (toolEvent.type === ChatEventType.toolProgress) { | ||
| return { | ||
| type: ChatEventType.toolProgress, | ||
| data: { | ||
| ...toolEvent.data, | ||
| tool_call_id: toolCallId, | ||
| }, | ||
| }; | ||
| } | ||
| throw new Error(`Invalid tool call type ${toolEvent.type}`); | ||
| }; | ||
| }; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ import type { BaseMessage } from '@langchain/core/messages'; | |
| import { isToolMessage } from '@langchain/core/messages'; | ||
| import { messagesStateReducer } from '@langchain/langgraph'; | ||
| import { ToolNode } from '@langchain/langgraph/prebuilt'; | ||
| import type { ScopedModel } from '@kbn/onechat-server'; | ||
| import type { ScopedModel, ToolEventEmitter } from '@kbn/onechat-server'; | ||
| import type { ElasticsearchClient, Logger } from '@kbn/core/server'; | ||
| import type { ToolResult } from '@kbn/onechat-common/tools'; | ||
| import { ToolResultType } from '@kbn/onechat-common/tools'; | ||
|
|
@@ -19,6 +19,7 @@ import { indexExplorer } from '../index_explorer'; | |
| import { createNaturalLanguageSearchTool, createRelevanceSearchTool } from './inner_tools'; | ||
| import { getSearchPrompt } from './prompts'; | ||
| import type { SearchTarget } from './types'; | ||
| import { progressMessages } from './i18n'; | ||
|
|
||
| const StateAnnotation = Annotation.Root({ | ||
| // inputs | ||
|
|
@@ -45,19 +46,23 @@ export const createSearchToolGraph = ({ | |
| model, | ||
| esClient, | ||
| logger, | ||
| events, | ||
| }: { | ||
| model: ScopedModel; | ||
| esClient: ElasticsearchClient; | ||
| logger: Logger; | ||
| events?: ToolEventEmitter; | ||
| }) => { | ||
| const tools = [ | ||
| createRelevanceSearchTool({ model, esClient }), | ||
| createNaturalLanguageSearchTool({ model, esClient }), | ||
| createRelevanceSearchTool({ model, esClient, events }), | ||
| createNaturalLanguageSearchTool({ model, esClient, events }), | ||
| ]; | ||
|
|
||
| const toolNode = new ToolNode<typeof StateAnnotation.State.messages>(tools); | ||
|
|
||
| const selectAndValidateIndex = async (state: StateType) => { | ||
| events?.reportProgress(progressMessages.selectingTarget()); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Example of reporting progress for the search tool |
||
|
|
||
| const explorerRes = await indexExplorer({ | ||
| nlQuery: state.nlQuery, | ||
| indexPattern: state.targetPattern ?? '*', | ||
|
|
@@ -69,6 +74,8 @@ export const createSearchToolGraph = ({ | |
|
|
||
| if (explorerRes.resources.length > 0) { | ||
| const selectedResource = explorerRes.resources[0]; | ||
| events?.reportProgress(progressMessages.selectedTarget(selectedResource.name)); | ||
|
|
||
| return { | ||
| indexIsValid: true, | ||
| searchTarget: { type: selectedResource.type, name: selectedResource.name }, | ||
|
|
@@ -90,6 +97,7 @@ export const createSearchToolGraph = ({ | |
| }); | ||
|
|
||
| const callSearchAgent = async (state: StateType) => { | ||
| events?.reportProgress(progressMessages.resolvingSearchStrategy()); | ||
| const response = await searchModel.invoke( | ||
| getSearchPrompt({ nlQuery: state.nlQuery, searchTarget: state.searchTarget }) | ||
| ); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| /* | ||
| * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
| * or more contributor license agreements. Licensed under the Elastic License | ||
| * 2.0; you may not use this file except in compliance with the Elastic License | ||
| * 2.0. | ||
| */ | ||
|
|
||
| import { i18n } from '@kbn/i18n'; | ||
|
|
||
| export const progressMessages = { | ||
| selectingTarget: () => { | ||
| return i18n.translate('xpack.onechat.tools.search.progress.selectingTarget', { | ||
| defaultMessage: 'Selecting the best target for this query', | ||
| }); | ||
| }, | ||
| selectedTarget: (target: string) => { | ||
| return i18n.translate('xpack.onechat.tools.search.progress.selectedTarget', { | ||
| defaultMessage: 'Selected "{target}" as the next search target', | ||
| values: { | ||
| target, | ||
| }, | ||
| }); | ||
| }, | ||
| resolvingSearchStrategy: () => { | ||
| return i18n.translate('xpack.onechat.tools.search.progress.searchStrategy', { | ||
| defaultMessage: 'Thinking about the search strategy to use', | ||
| }); | ||
| }, | ||
| performingRelevanceSearch: ({ term }: { term: string }) => { | ||
| return i18n.translate('xpack.onechat.tools.search.progress.performingRelevanceSearch', { | ||
| defaultMessage: 'Searching documents for "{term}"', | ||
| values: { | ||
| term, | ||
| }, | ||
| }); | ||
| }, | ||
| performingNlSearch: ({ query }: { query: string }) => { | ||
| return i18n.translate('xpack.onechat.tools.search.progress.performingTextSearch', { | ||
| defaultMessage: 'Generating an ES|QL for "{query}"', | ||
| values: { | ||
| query, | ||
| }, | ||
| }); | ||
| }, | ||
| }; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Out of the scope of the issue, but I took the opportunity to remove the arguably not useful generics around conversation round types.