From 569975a64910086039c79f7d1e16ffc396b20560 Mon Sep 17 00:00:00 2001
From: Steph Milovic
Date: Tue, 6 Aug 2024 13:28:58 -0600
Subject: [PATCH] done

---
 .../execute_custom_llm_chain/index.ts        | 54 +++++++++++--------
 .../graphs/default_assistant_graph/graph.ts  | 10 ++--
 .../graphs/default_assistant_graph/index.ts  | 27 ++++++----
 .../nodes/generate_chat_title.ts             |  2 +
 .../default_assistant_graph/nodes/respond.ts |  4 +-
 .../graphs/default_assistant_graph/types.ts  |  2 -
 6 files changed, 55 insertions(+), 44 deletions(-)

diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
index 38759d4d68ea3..0af5f0453ec8b 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts
@@ -57,23 +57,32 @@ export const callAgentExecutor: AgentExecutor = async ({
 
   const isOpenAI = llmType === 'openai';
   const llmClass = getLlmClass(llmType, bedrockChatEnabled);
-  const llm = new llmClass({
-    actionsClient,
-    connectorId,
-    llmType,
-    logger,
-    // possible client model override,
-    // let this be undefined otherwise so the connector handles the model
-    model: request.body.model,
-    // ensure this is defined because we default to it in the language_models
-    // This is where the LangSmith logs (Metadata > Invocation Params) are set
-    temperature: getDefaultArguments(llmType).temperature,
-    signal: abortSignal,
-    streaming: isStream,
-    // prevents the agent from retrying on failure
-    // failure could be due to bad connector, we should deliver that result to the client asap
-    maxRetries: 0,
-  });
+  /**
+   * Creates a new instance of llmClass.
+   *
+   * This function ensures that a new llmClass instance is created every time it is called.
+   * This is necessary to avoid any potential side effects from shared state. By always
+   * creating a new instance, we prevent other consumers of the llm from binding to it and
+   * changing its state unintentionally. For this reason, never assign the result to a variable (e.g. const llm = createLlmInstance()); call it at each use site instead.
+   */
+  const createLlmInstance = () =>
+    new llmClass({
+      actionsClient,
+      connectorId,
+      llmType,
+      logger,
+      // possible client model override,
+      // let this be undefined otherwise so the connector handles the model
+      model: request.body.model,
+      // ensure this is defined because we default to it in the language_models
+      // This is where the LangSmith logs (Metadata > Invocation Params) are set
+      temperature: getDefaultArguments(llmType).temperature,
+      signal: abortSignal,
+      streaming: isStream,
+      // prevents the agent from retrying on failure
+      // failure could be due to bad connector, we should deliver that result to the client asap
+      maxRetries: 0,
+    });
 
   const anonymizationFieldsRes =
     await dataClients?.anonymizationFieldsDataClient?.findDocuments({
@@ -99,7 +108,7 @@ export const callAgentExecutor: AgentExecutor = async ({
   const modelExists = await esStore.isModelInstalled();
 
   // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
-  const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever(10));
+  const chain = RetrievalQAChain.fromLLM(createLlmInstance(), esStore.asRetriever(10));
 
   // Fetch any applicable tools that the source plugin may have registered
   const assistantToolParams: AssistantToolParams = {
@@ -108,7 +117,6 @@ export const callAgentExecutor: AgentExecutor = async ({
     chain,
     esClient,
     isEnabledKnowledgeBase: true,
-    llm,
     logger,
     modelExists,
     onNewReplacements,
@@ -118,7 +126,7 @@ export const callAgentExecutor: AgentExecutor = async ({
   };
 
   const tools: ToolInterface[] = assistantTools.flatMap(
-    (tool) => tool.getTool(assistantToolParams) ?? []
+    (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
   );
 
   logger.debug(
@@ -132,14 +140,14 @@ export const callAgentExecutor: AgentExecutor = async ({
   };
   // isOpenAI check is not on agentType alone because typescript doesn't like
   const executor = isOpenAI
-    ? await initializeAgentExecutorWithOptions(tools, llm, {
+    ? await initializeAgentExecutorWithOptions(tools, createLlmInstance(), {
         agentType: 'openai-functions',
         ...executorArgs,
       })
     : llmType === 'bedrock' && bedrockChatEnabled
     ? new lcAgentExecutor({
         agent: await createToolCallingAgent({
-          llm,
+          llm: createLlmInstance(),
           tools,
           prompt: ChatPromptTemplate.fromMessages([
             ['system', 'You are a helpful assistant'],
@@ -151,7 +159,7 @@ export const callAgentExecutor: AgentExecutor = async ({
       }),
       tools,
     })
-    : await initializeAgentExecutorWithOptions(tools, llm, {
+    : await initializeAgentExecutorWithOptions(tools, createLlmInstance(), {
        agentType: 'structured-chat-zero-shot-react-description',
        ...executorArgs,
        returnIntermediateSteps: false,
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts
index 7ee0e6912b563..6eac3f1c98303 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/graph.ts
@@ -41,8 +41,7 @@ interface GetDefaultAssistantGraphParams {
   agentRunnable: AgentRunnableSequence;
   dataClients?: AssistantDataClients;
   conversationId?: string;
-  getLlmInstance: () => BaseChatModel;
-  llm: BaseChatModel;
+  createLlmInstance: () => BaseChatModel;
   logger: Logger;
   tools: StructuredTool[];
   responseLanguage: string;
@@ -61,8 +60,7 @@ export const getDefaultAssistantGraph = ({
   agentRunnable,
   conversationId,
   dataClients,
-  getLlmInstance,
-  llm,
+  createLlmInstance,
   logger,
   responseLanguage,
   tools,
@@ -106,7 +104,6 @@ export const getDefaultAssistantGraph = ({
 
   // Default node parameters
   const nodeParams: NodeParamsBase = {
-    model: llm,
     logger,
   };
 
@@ -131,6 +128,7 @@ export const getDefaultAssistantGraph = ({
   const generateChatTitleNode = (state: AgentState) =>
     generateChatTitle({
       ...nodeParams,
+      model: createLlmInstance(),
       state,
       responseLanguage,
     });
@@ -154,7 +152,7 @@ export const getDefaultAssistantGraph = ({
   const respondNode = (state: AgentState) =>
     respond({
       ...nodeParams,
-      llm: getLlmInstance(),
+      model: createLlmInstance(),
       state,
     });
   const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state });
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
index 40336561149e6..758a3a757eb76 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/index.ts
@@ -57,7 +57,16 @@ export const callAssistantGraph: AgentExecutor = async ({
   const logger = parentLogger.get('defaultAssistantGraph');
   const isOpenAI = llmType === 'openai';
   const llmClass = getLlmClass(llmType, bedrockChatEnabled);
-  const getLlmInstance = () =>
+
+  /**
+   * Creates a new instance of llmClass.
+   *
+   * This function ensures that a new llmClass instance is created every time it is called.
+   * This is necessary to avoid any potential side effects from shared state. By always
+   * creating a new instance, we prevent other consumers of the llm from binding to it and
+   * changing its state unintentionally. For this reason, never assign the result to a variable (e.g. const llm = createLlmInstance()); call it at each use site instead.
+   */
+  const createLlmInstance = () =>
     new llmClass({
       actionsClient,
       connectorId,
@@ -76,8 +85,6 @@ export const callAssistantGraph: AgentExecutor = async ({
       maxRetries: 0,
     });
 
-  const llm = getLlmInstance();
-
   const anonymizationFieldsRes =
     await dataClients?.anonymizationFieldsDataClient?.findDocuments({
       perPage: 1000,
@@ -93,7 +100,7 @@ export const callAssistantGraph: AgentExecutor = async ({
   const modelExists = await esStore.isModelInstalled();
 
   // Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
-  const chain = RetrievalQAChain.fromLLM(getLlmInstance(), esStore.asRetriever(10));
+  const chain = RetrievalQAChain.fromLLM(createLlmInstance(), esStore.asRetriever(10));
 
   // Check if KB is available
   const isEnabledKnowledgeBase = (await dataClients?.kbDataClient?.isModelDeployed()) ?? false;
@@ -106,7 +113,6 @@ export const callAssistantGraph: AgentExecutor = async ({
     esClient,
     isEnabledKnowledgeBase,
     kbDataClient: dataClients?.kbDataClient,
-    llm,
     logger,
     modelExists,
     onNewReplacements,
@@ -116,26 +122,26 @@ export const callAssistantGraph: AgentExecutor = async ({
   };
 
   const tools: StructuredTool[] = assistantTools.flatMap(
-    (tool) => tool.getTool(assistantToolParams) ?? []
+    (tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
   );
 
   const agentRunnable = isOpenAI
     ? await createOpenAIFunctionsAgent({
-        llm,
+        llm: createLlmInstance(),
         tools,
         prompt: openAIFunctionAgentPrompt,
         streamRunnable: isStream,
       })
     : llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled
     ? await createToolCallingAgent({
-        llm,
+        llm: createLlmInstance(),
         tools,
         prompt: llmType === 'bedrock' ? bedrockToolCallingAgentPrompt : geminiToolCallingAgentPrompt,
         streamRunnable: isStream,
       })
     : await createStructuredChatAgent({
-        llm,
+        llm: createLlmInstance(),
         tools,
         prompt: structuredChatAgentPrompt,
         streamRunnable: isStream,
@@ -147,9 +153,8 @@ export const callAssistantGraph: AgentExecutor = async ({
     agentRunnable,
     conversationId,
     dataClients,
-    llm,
     // we need to pass it like this or streaming does not work for bedrock
-    getLlmInstance,
+    createLlmInstance,
     logger,
     tools,
     responseLanguage,
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts
index 74cf0fca3929a..9cda33fdbabbc 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/generate_chat_title.ts
@@ -7,6 +7,7 @@
 
 import { StringOutputParser } from '@langchain/core/output_parsers';
 import { ChatPromptTemplate } from '@langchain/core/prompts';
+import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import { AgentState, NodeParamsBase } from '../types';
 
 export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) =>
@@ -25,6 +26,7 @@ export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) =>
 export interface GenerateChatTitleParams extends NodeParamsBase {
   responseLanguage: string;
   state: AgentState;
+  model: BaseChatModel;
 }
 
 export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle';
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
index bb3b3a518e06d..7c11b96bbca0d 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/nodes/respond.ts
@@ -11,7 +11,7 @@ import { AGENT_NODE_TAG } from './run_agent';
 import { AgentState } from '../types';
 
 export const RESPOND_NODE = 'respond';
-export const respond = async ({ llm, state }: { llm: BaseChatModel; state: AgentState }) => {
+export const respond = async ({ model, state }: { model: BaseChatModel; state: AgentState }) => {
   if (state?.agentOutcome && 'returnValues' in state.agentOutcome) {
     const userMessage = [
       'user',
@@ -21,7 +21,7 @@ export const respond = async ({ model, state }: { model: BaseChatModel; state: AgentState }) => {
       Do not verify, confirm or anything else. Just reply with the same content as provided above.`,
     ] as [StringWithAutocomplete<'user'>, string];
 
-    const responseMessage = await llm
+    const responseMessage = await model
       // use AGENT_NODE_TAG to identify as agent node for stream parsing
       .withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] })
       .invoke([userMessage]);
diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts
index 4ee4f1ba1b148..5d86a0f6b97ed 100644
--- a/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts
+++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/graphs/default_assistant_graph/types.ts
@@ -7,7 +7,6 @@
 
 import { BaseMessage } from '@langchain/core/messages';
 import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
-import { BaseChatModel } from '@langchain/core/language_models/chat_models';
 import type { Logger } from '@kbn/logging';
 import { ConversationResponse } from '@kbn/elastic-assistant-common';
 
@@ -25,5 +24,4 @@ export interface AgentState extends AgentStateBase {
 
 export interface NodeParamsBase {
   logger: Logger;
-  model: BaseChatModel;
 }
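A note on the pattern this patch applies: every shared `llm` value is replaced with a `createLlmInstance` factory that is invoked at each use site (the tools, each agent constructor, and the `generateChatTitle`/`respond` graph nodes), because a single shared instance lets one consumer's configuration leak into another's. The sketch below illustrates that hazard under stated assumptions; `FakeChatModel` and `configure` are hypothetical stand-ins for illustration, not the Kibana `llmClass` implementations.

```ts
// Hypothetical stand-in: a chat model with mutable per-call configuration.
// The patch's own doc comment attributes comparable shared-state hazards
// to the real llmClass instances.
class FakeChatModel {
  tags: string[] = [];
  configure(tags: string[]): this {
    this.tags = tags; // mutates in place: every holder of this instance sees the change
    return this;
  }
}

// Anti-pattern (what the patch removes): one instance shared by all consumers.
const shared = new FakeChatModel();
const forTool = shared.configure(['tool:esql']);
shared.configure(['node:respond']); // a later consumer reconfigures the same object
console.log(forTool.tags); // ['node:respond'] — the tool's configuration was clobbered

// Pattern the patch adopts: a factory, called at each use site.
const createLlmInstance = () => new FakeChatModel();
const toolModel = createLlmInstance().configure(['tool:esql']);
const respondModel = createLlmInstance().configure(['node:respond']);
console.log(toolModel.tags); // ['tool:esql'] — isolated from respondModel
console.log(respondModel.tags); // ['node:respond']
```

Passing the factory itself, as `graph.ts` now does with `createLlmInstance`, leaves the decision of when to instantiate inside each consumer, which is why the `respond` and `generateChatTitle` nodes each receive a fresh model rather than the previously shared one.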