@@ -57,23 +57,32 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
const isOpenAI = llmType === 'openai';
const llmClass = getLlmClass(llmType, bedrockChatEnabled);

const llm = new llmClass({
actionsClient,
connectorId,
llmType,
logger,
// possible client model override,
// let this be undefined otherwise so the connector handles the model
model: request.body.model,
// ensure this is defined because we default to it in the language_models
// This is where the LangSmith logs (Metadata > Invocation Params) are set
temperature: getDefaultArguments(llmType).temperature,
signal: abortSignal,
streaming: isStream,
// prevents the agent from retrying on failure
// failure could be due to bad connector, we should deliver that result to the client asap
maxRetries: 0,
});
/**
 * Creates a new instance of llmClass.
 *
 * This function ensures that a new llmClass instance is created every time it is called.
 * This is necessary to avoid any potential side effects from shared state. By always
 * creating a new instance, we prevent other consumers of the LLM from binding to it and
 * mutating its state unintentionally. For this reason, never assign the result to a
 * shared variable (e.g. `const llm = createLlmInstance()`).
 */
const createLlmInstance = () =>
new llmClass({
actionsClient,
connectorId,
llmType,
logger,
// possible client model override,
// let this be undefined otherwise so the connector handles the model
model: request.body.model,
// ensure this is defined because we default to it in the language_models
// This is where the LangSmith logs (Metadata > Invocation Params) are set
temperature: getDefaultArguments(llmType).temperature,
signal: abortSignal,
streaming: isStream,
// prevents the agent from retrying on failure
// failure could be due to bad connector, we should deliver that result to the client asap
maxRetries: 0,
});

const anonymizationFieldsRes =
await dataClients?.anonymizationFieldsDataClient?.findDocuments<EsAnonymizationFieldsSchema>({
@@ -99,7 +108,7 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
const modelExists = await esStore.isModelInstalled();

// Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever(10));
const chain = RetrievalQAChain.fromLLM(createLlmInstance(), esStore.asRetriever(10));

// Fetch any applicable tools that the source plugin may have registered
const assistantToolParams: AssistantToolParams = {
@@ -108,7 +117,6 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
chain,
esClient,
isEnabledKnowledgeBase: true,
llm,
logger,
modelExists,
onNewReplacements,
@@ -118,7 +126,7 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
};

const tools: ToolInterface[] = assistantTools.flatMap(
(tool) => tool.getTool(assistantToolParams) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
);

logger.debug(`applicable tools: ${JSON.stringify(tools.map((t) => t.name).join(', '), null, 2)}`);
@@ -130,14 +138,14 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
};
// isOpenAI is checked separately from agentType because TypeScript doesn't narrow the type otherwise
const executor = isOpenAI
? await initializeAgentExecutorWithOptions(tools, llm, {
? await initializeAgentExecutorWithOptions(tools, createLlmInstance(), {
agentType: 'openai-functions',
...executorArgs,
})
: llmType === 'bedrock' && bedrockChatEnabled
? new lcAgentExecutor({
agent: await createToolCallingAgent({
llm,
llm: createLlmInstance(),
tools,
prompt: ChatPromptTemplate.fromMessages([
['system', 'You are a helpful assistant'],
@@ -149,7 +157,7 @@ export const callAgentExecutor: AgentExecutor<true | false> = async ({
}),
tools,
})
: await initializeAgentExecutorWithOptions(tools, llm, {
: await initializeAgentExecutorWithOptions(tools, createLlmInstance(), {
agentType: 'structured-chat-zero-shot-react-description',
...executorArgs,
returnIntermediateSteps: false,
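(Editor's illustration — a minimal sketch of the shared-state hazard the new factory comment describes; not code from this PR. `ChatOpenAI` stands in for the connector-backed model class, and the callback shape is simplified.)

import { ChatOpenAI } from '@langchain/openai';

const createLlmInstance = () => new ChatOpenAI({ temperature: 0, maxRetries: 0 });

// Shared-instance hazard: LangChain chat models are mutable objects, so fields
// attached by one consumer are visible to every other holder of the same reference.
const shared = createLlmInstance();
shared.callbacks = [{ handleLLMStart: async () => console.log('tool A started') }];
// Any later consumer of `shared` now fires tool A's callback too.

// With the factory, each consumer configures its own instance, so callbacks,
// bound tools, and streaming flags cannot leak between the chain, the tools,
// and the agent executor.
const chainLlm = createLlmInstance();
const agentLlm = createLlmInstance();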
@@ -41,8 +41,7 @@ interface GetDefaultAssistantGraphParams {
agentRunnable: AgentRunnableSequence;
dataClients?: AssistantDataClients;
conversationId?: string;
getLlmInstance: () => BaseChatModel;
llm: BaseChatModel;
createLlmInstance: () => BaseChatModel;
logger: Logger;
tools: StructuredTool[];
responseLanguage: string;
@@ -61,8 +60,7 @@ export const getDefaultAssistantGraph = ({
agentRunnable,
conversationId,
dataClients,
getLlmInstance,
llm,
createLlmInstance,
logger,
responseLanguage,
tools,
@@ -106,7 +104,6 @@ export const getDefaultAssistantGraph = ({

// Default node parameters
const nodeParams: NodeParamsBase = {
model: llm,
logger,
};

@@ -131,6 +128,7 @@ export const getDefaultAssistantGraph = ({
const generateChatTitleNode = (state: AgentState) =>
generateChatTitle({
...nodeParams,
model: createLlmInstance(),
state,
responseLanguage,
});
@@ -154,7 +152,7 @@ export const getDefaultAssistantGraph = ({
const respondNode = (state: AgentState) =>
respond({
...nodeParams,
llm: getLlmInstance(),
model: createLlmInstance(),
state,
});
const shouldContinueEdge = (state: AgentState) => shouldContinue({ ...nodeParams, state });
@@ -57,7 +57,16 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
const logger = parentLogger.get('defaultAssistantGraph');
const isOpenAI = llmType === 'openai';
const llmClass = getLlmClass(llmType, bedrockChatEnabled);
const getLlmInstance = () =>

/**
 * Creates a new instance of llmClass.
 *
 * This function ensures that a new llmClass instance is created every time it is called.
 * This is necessary to avoid any potential side effects from shared state. By always
 * creating a new instance, we prevent other consumers of the LLM from binding to it and
 * mutating its state unintentionally. For this reason, never assign the result to a
 * shared variable (e.g. `const llm = createLlmInstance()`).
 */
const createLlmInstance = () =>
new llmClass({
actionsClient,
connectorId,
@@ -76,8 +85,6 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
maxRetries: 0,
});

const llm = getLlmInstance();

const anonymizationFieldsRes =
await dataClients?.anonymizationFieldsDataClient?.findDocuments<EsAnonymizationFieldsSchema>({
perPage: 1000,
@@ -93,7 +100,7 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
const modelExists = await esStore.isModelInstalled();

// Create a chain that uses the ELSER backed ElasticsearchStore, override k=10 for esql query generation for now
const chain = RetrievalQAChain.fromLLM(getLlmInstance(), esStore.asRetriever(10));
const chain = RetrievalQAChain.fromLLM(createLlmInstance(), esStore.asRetriever(10));

// Check if KB is available
const isEnabledKnowledgeBase = (await dataClients?.kbDataClient?.isModelDeployed()) ?? false;
@@ -106,7 +113,6 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
esClient,
isEnabledKnowledgeBase,
kbDataClient: dataClients?.kbDataClient,
llm,
logger,
modelExists,
onNewReplacements,
@@ -116,26 +122,26 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
};

const tools: StructuredTool[] = assistantTools.flatMap(
(tool) => tool.getTool(assistantToolParams) ?? []
(tool) => tool.getTool({ ...assistantToolParams, llm: createLlmInstance() }) ?? []
);

const agentRunnable = isOpenAI
? await createOpenAIFunctionsAgent({
llm,
llm: createLlmInstance(),
tools,
prompt: openAIFunctionAgentPrompt,
streamRunnable: isStream,
})
: llmType && ['bedrock', 'gemini'].includes(llmType) && bedrockChatEnabled
? await createToolCallingAgent({
llm,
llm: createLlmInstance(),
tools,
prompt:
llmType === 'bedrock' ? bedrockToolCallingAgentPrompt : geminiToolCallingAgentPrompt,
streamRunnable: isStream,
})
: await createStructuredChatAgent({
llm,
llm: createLlmInstance(),
tools,
prompt: structuredChatAgentPrompt,
streamRunnable: isStream,
@@ -147,9 +153,8 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
agentRunnable,
conversationId,
dataClients,
llm,
// we need to pass the factory itself, or streaming does not work for Bedrock
getLlmInstance,
createLlmInstance,
logger,
tools,
responseLanguage,
@@ -7,6 +7,7 @@
import { StringOutputParser } from '@langchain/core/output_parsers';

import { ChatPromptTemplate } from '@langchain/core/prompts';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import { AgentState, NodeParamsBase } from '../types';

export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) =>
@@ -25,6 +26,7 @@ export const GENERATE_CHAT_TITLE_PROMPT = (responseLanguage: string) =>
export interface GenerateChatTitleParams extends NodeParamsBase {
responseLanguage: string;
state: AgentState;
model: BaseChatModel;
}

export const GENERATE_CHAT_TITLE_NODE = 'generateChatTitle';
@@ -11,7 +11,7 @@ import { AGENT_NODE_TAG } from './run_agent';
import { AgentState } from '../types';

export const RESPOND_NODE = 'respond';
export const respond = async ({ llm, state }: { llm: BaseChatModel; state: AgentState }) => {
export const respond = async ({ model, state }: { model: BaseChatModel; state: AgentState }) => {
if (state?.agentOutcome && 'returnValues' in state.agentOutcome) {
const userMessage = [
'user',
@@ -21,7 +21,7 @@ export const respond = async ({ llm, state }: { llm: BaseChatModel; state: Agent
Do not verify, confirm or anything else. Just reply with the same content as provided above.`,
] as [StringWithAutocomplete<'user'>, string];

const responseMessage = await llm
const responseMessage = await model
// use AGENT_NODE_TAG to identify as agent node for stream parsing
.withConfig({ runName: 'Summarizer', tags: [AGENT_NODE_TAG] })
.invoke([userMessage]);
@@ -7,7 +7,6 @@

import { BaseMessage } from '@langchain/core/messages';
import { AgentAction, AgentFinish, AgentStep } from '@langchain/core/agents';
import { BaseChatModel } from '@langchain/core/language_models/chat_models';
import type { Logger } from '@kbn/logging';
import { ConversationResponse } from '@kbn/elastic-assistant-common';

@@ -25,5 +24,4 @@ export interface AgentState extends AgentStateBase {

export interface NodeParamsBase {
logger: Logger;
model: BaseChatModel;
}
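(Editor's illustration — a hypothetical Jest sketch of the guarantee the factory provides. This test is not part of the PR, and `createLlmInstance` is a local closure in the executors, so it is assumed reachable here only for demonstration.)

import { expect, test } from '@jest/globals';

test('createLlmInstance returns independent model instances', () => {
  const a = createLlmInstance();
  const b = createLlmInstance();
  expect(a).not.toBe(b); // distinct objects, no shared reference
  a.callbacks = [{ handleLLMStart: async () => {} }];
  expect(b.callbacks).toBeUndefined(); // configuring one instance does not leak to the other
});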