Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,28 @@ export enum AgentType {
conversational = 'conversational',
}

/**
* Execution mode for agents.
*/
export enum AgentMode {
/**
* Normal (Q/A) mode
*/
normal = 'normal',
/**
* "Think more" mode
*/
reason = 'reason',
/**
* "Plan-and-execute" mode
*/
plan = 'plan',
/**
* "Deep-research" mode
*/
research = 'research',
}

/**
* ID of the onechat default conversational agent
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

export {
AgentType,
AgentMode,
OneChatDefaultAgentId,
OneChatDefaultAgentProviderId,
type AgentDescriptor,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ export interface AssistantResponse {
message: string;
}

export enum ConversationRoundStepType {
toolCall = 'toolCall',
reasoning = 'reasoning',
}

// tool call step

export type ConversationRoundStepMixin<TType extends ConversationRoundStepType, TData> = TData & {
type: TType;
};

/**
* Represents a tool call with the corresponding result.
*/
Expand All @@ -51,14 +62,6 @@ export interface ToolCallWithResult {
result: string;
}

export enum ConversationRoundStepType {
toolCall = 'toolCall',
}

export type ConversationRoundStepMixin<TType extends ConversationRoundStepType, TData> = TData & {
type: TType;
};

export type ToolCallStep = ConversationRoundStepMixin<
ConversationRoundStepType.toolCall,
ToolCallWithResult
Expand All @@ -68,8 +71,26 @@ export const isToolCallStep = (step: ConversationRoundStep): step is ToolCallSte
return step.type === ConversationRoundStepType.toolCall;
};

// may have more type of steps later.
export type ConversationRoundStep = ToolCallStep;
// reasoning step

export interface ReasoningStepData {
/** plain text reasoning content */
reasoning: string;
}

export type ReasoningStep = ConversationRoundStepMixin<
ConversationRoundStepType.reasoning,
ReasoningStepData
>;

export const isReasoningStep = (step: ConversationRoundStep): step is ReasoningStep => {
return step.type === ConversationRoundStepType.reasoning;
};

/**
* Defines all possible types for round steps.
*/
export type ConversationRoundStep = ToolCallStep | ReasoningStep;

/**
* Represents a round in a conversation, containing all the information
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,14 @@ export {
type ToolCallWithResult,
type ConversationRound,
type Conversation,
type ConversationRoundStepMixin,
type ToolCallStep,
type ConversationRoundStep,
type ReasoningStepData,
type ReasoningStep,
ConversationRoundStepType,
isToolCallStep,
isReasoningStep,
} from './conversation';
export {
ChatEventType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ export {
OneChatDefaultAgentId,
OneChatDefaultAgentProviderId,
AgentType,
AgentMode,
type AgentDescriptor,
type AgentIdentifier,
type PlainIdAgentIdentifier,
Expand All @@ -66,11 +67,14 @@ export {
type MessageCompleteEvent,
type RoundCompleteEventData,
type RoundCompleteEvent,
type ReasoningEventData,
type ReasoningEvent,
isToolCallEvent,
isToolResultEvent,
isMessageChunkEvent,
isMessageCompleteEvent,
isRoundCompleteEvent,
isReasoningEvent,
isSerializedAgentIdentifier,
isPlainAgentIdentifier,
isStructuredAgentIdentifier,
Expand All @@ -94,6 +98,9 @@ export {
isConversationUpdatedEvent,
type ToolCallStep,
type ConversationRoundStep,
type ReasoningStepData,
type ReasoningStep,
ConversationRoundStepType,
isToolCallStep,
isReasoningStep,
} from './chat';
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import {
} from '@kbn/onechat-common/agents';
import { extractTextContent } from './messages';

export const isStreamEvent = (input: any): input is LangchainStreamEvent => {
return 'event' in input && 'name' in input;
};

export const matchGraphName = (event: LangchainStreamEvent, graphName: string): boolean => {
return event.metadata.graphName === graphName;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
*/

export {
isStreamEvent,
matchGraphName,
matchGraphNode,
matchName,
Expand All @@ -15,4 +16,11 @@ export {
createMessageEvent,
createReasoningEvent,
} from './graph_events';
export { extractTextContent } from './messages';
export { extractTextContent, extractToolCalls, type ToolCall } from './messages';
export {
toolsToLangchain,
toolToLangchain,
toolIdentifierFromToolCall,
type ToolIdMapping,
type ToolsAndMappings,
} from './tools';
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/

import { BaseMessage, MessageContentComplex } from '@langchain/core/messages';
import { BaseMessage, MessageContentComplex, isAIMessage } from '@langchain/core/messages';

/**
* Extract the text content from a langchain message or chunk.
Expand All @@ -23,3 +23,30 @@ export const extractTextContent = (message: BaseMessage): string => {
return content;
}
};

export interface ToolCall {
toolCallId: string;
toolName: string;
args: Record<string, any>;
}

/**
* Extracts the tool calls from a message.
*/
export const extractToolCalls = (message: BaseMessage): ToolCall[] => {
if (isAIMessage(message)) {
return (
message.tool_calls?.map<ToolCall>((toolCall) => {
if (!toolCall.id) {
throw new Error('Tool call must have an id');
}
return {
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.args,
};
}) ?? []
);
}
return [];
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { StructuredTool, tool as toTool } from '@langchain/core/tools';
import { Logger } from '@kbn/logging';
import type { KibanaRequest } from '@kbn/core-http-server';
import {
toSerializedToolIdentifier,
type SerializedToolIdentifier,
type StructuredToolIdentifier,
toStructuredToolIdentifier,
unknownToolProviderId,
} from '@kbn/onechat-common';
import type { ToolProvider, ExecutableTool } from '@kbn/onechat-server';
import type { ToolCall } from './messages';

export type ToolIdMapping = Map<string, SerializedToolIdentifier>;

export interface ToolsAndMappings {
/**
* The tools in langchain format
*/
tools: StructuredTool[];
/**
* ID mapping that can be used to retrieve the full identifier from the langchain tool id.
*/
idMappings: ToolIdMapping;
}

export const toolsToLangchain = async ({
request,
tools,
logger,
}: {
request: KibanaRequest;
tools: ToolProvider | ExecutableTool[];
logger: Logger;
}): Promise<ToolsAndMappings> => {
const allTools = Array.isArray(tools) ? tools : await tools.list({ request });
const mappings = createToolIdMappings(allTools);

const reverseMappings = reverseMap(mappings);

const convertedTools = await Promise.all(
allTools.map((tool) => {
const toolId = reverseMappings.get(
toSerializedToolIdentifier({ toolId: tool.id, providerId: tool.meta.providerId })
);
return toolToLangchain({ tool, logger, toolId });
})
);

return {
tools: convertedTools,
idMappings: mappings,
};
};

export const createToolIdMappings = (tools: ExecutableTool[]): ToolIdMapping => {
const toolIds = new Set<string>();
const mapping: ToolIdMapping = new Map();

for (const tool of tools) {
let toolId = tool.id;
let index = 1;
while (toolIds.has(toolId)) {
toolId = `${toolId}_${index++}`;
}
toolIds.add(toolId);
mapping.set(
toolId,
toSerializedToolIdentifier({ toolId: tool.id, providerId: tool.meta.providerId })
);
}

return mapping;
};

export const toolToLangchain = ({
tool,
toolId,
logger,
}: {
tool: ExecutableTool;
toolId?: string;
logger: Logger;
}): StructuredTool => {
return toTool(
async (input) => {
try {
const toolReturn = await tool.execute({ toolParams: input });
return JSON.stringify(toolReturn.result);
} catch (e) {
logger.warn(`error calling tool ${tool.id}: ${e.message}`);
throw e;
}
},
{
name: toolId ?? tool.id,
description: tool.description,
schema: tool.schema,
metadata: {
serializedToolId: toSerializedToolIdentifier({
toolId: tool.id,
providerId: tool.meta.providerId,
}),
},
}
);
};

export const toolIdentifierFromToolCall = (
toolCall: ToolCall,
mapping: ToolIdMapping
): StructuredToolIdentifier => {
return toStructuredToolIdentifier(
mapping.get(toolCall.toolName) ?? {
toolId: toolCall.toolName,
providerId: unknownToolProviderId,
}
);
};

function reverseMap<K, V>(map: Map<K, V>): Map<V, K> {
const reversed = new Map<V, K>();
for (const [key, value] of map.entries()) {
if (reversed.has(value)) {
throw new Error(`Duplicate value detected while reversing map: ${value}`);
}
reversed.set(value, key);
}
return reversed;
}
Loading