Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions api_docs/kbn_elastic_assistant_common.devdocs.json
Original file line number Diff line number Diff line change
Expand Up @@ -6031,15 +6031,15 @@
},
{
"parentPluginId": "@kbn/elastic-assistant-common",
"id": "def-common.INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG",
"id": "def-common.INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG",
"type": "string",
"tags": [],
"label": "INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG",
"label": "INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG",
"description": [
"\nThis feature flag enables the InferenceChatModel feature.\n\nIt may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:\n```\nfeature_flags.overrides:\n securitySolution.inferenceChatModelEnabled: true\n```"
"\nThis feature flag disables the InferenceChatModel feature.\n\nIt may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:\n```\nfeature_flags.overrides:\n securitySolution.inferenceChatModelDisabled: true\n```"
],
"signature": [
"\"securitySolution.inferenceChatModelEnabled\""
"\"securitySolution.inferenceChatModelDisabled\""
],
"path": "x-pack/platform/packages/shared/kbn-elastic-assistant-common/constants.ts",
"deprecated": false,
Expand Down Expand Up @@ -12022,4 +12022,4 @@
}
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ export const ATTACK_DISCOVERY_ALERTS_COMMON_INDEX_PREFIX =
'.alerts-security.attack.discovery.alerts' as const;

/**
* This feature flag enables the InferenceChatModel feature.
* This feature flag disables the InferenceChatModel feature.
*
* It may be overridden via the following setting in `kibana.yml` or `kibana.dev.yml`:
* ```
* feature_flags.overrides:
* securitySolution.inferenceChatModelEnabled: true
* securitySolution.inferenceChatModelDisabled: true
* ```
*/
export const INFERENCE_CHAT_MODEL_ENABLED_FEATURE_FLAG =
'securitySolution.inferenceChatModelEnabled' as const;
export const INFERENCE_CHAT_MODEL_DISABLED_FEATURE_FLAG =
'securitySolution.inferenceChatModelDisabled' as const;
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ export interface AgentExecutorParams<T extends boolean> {
llmType?: string;
isOssModel?: boolean;
inference: InferenceServerStart;
inferenceChatModelEnabled?: boolean;
inferenceChatModelDisabled?: boolean;
logger: Logger;
onNewReplacements?: (newReplacements: Replacements) => void;
replacements: Replacements;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export const agentRunnableFactory = async ({
llm,
llmType,
tools,
inferenceChatModelEnabled,
inferenceChatModelDisabled,
isOpenAI,
isStream,
prompt,
Expand All @@ -37,7 +37,7 @@ export const agentRunnableFactory = async ({
| ActionsClientChatVertexAI
| ActionsClientChatOpenAI
| InferenceChatModel;
inferenceChatModelEnabled: boolean;
inferenceChatModelDisabled: boolean;
isOpenAI: boolean;
llmType: string | undefined;
tools: StructuredToolInterface[] | ToolDefinition[];
Expand All @@ -51,7 +51,7 @@ export const agentRunnableFactory = async ({
prompt,
} as const;

if (!inferenceChatModelEnabled && (isOpenAI || llmType === 'inference')) {
if (inferenceChatModelDisabled && (isOpenAI || llmType === 'inference')) {
return createOpenAIToolsAgent(params);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ describe('streamGraph', () => {
logger: mockLogger,
onLlmResponse: mockOnLlmResponse,
request: mockRequest,
inferenceChatModelDisabled: true,
isEnabledKnowledgeBase: false,
telemetry: {
reportEvent: jest.fn(),
Expand All @@ -80,154 +81,156 @@ describe('streamGraph', () => {
});
});
describe('OpenAI Function Agent streaming', () => {
it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
],
describe('Inference Chat Model Disabled', () => {
it('should execute the graph in streaming mode - OpenAI + isOssModel = false', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [
[{ generationInfo: { finish_reason: 'stop' }, text: 'final message' }],
],
},
},
},
tags: [AGENT_NODE_TAG],
};
},
});
tags: [AGENT_NODE_TAG],
};
},
});

const response = await streamGraph(requestArgs);
const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
it('on_llm_end events with finish_reason != stop should not end the stream', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: { finish_reason: 'function_call' }, text: '' }]],
},
},
},
tags: [AGENT_NODE_TAG],
};
},
});
tags: [AGENT_NODE_TAG],
};
},
});

const response = await streamGraph(requestArgs);
const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).not.toHaveBeenCalled();
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).not.toHaveBeenCalled();
});
});
});
it('on_llm_end events without a finish_reason should end the stream', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
it('on_llm_end events without a finish_reason should end the stream', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: 'final message' }]],
},
},
},
tags: [AGENT_NODE_TAG],
};
},
});
tags: [AGENT_NODE_TAG],
};
},
});

const response = await streamGraph(requestArgs);
const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
});
it('on_llm_end events is called with chunks if there is no final text value', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
it('on_llm_end events is called with chunks if there is no final text value', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {
output: {
generations: [[{ generationInfo: {}, text: '' }]],
},
},
},
tags: [AGENT_NODE_TAG],
};
},
});
tags: [AGENT_NODE_TAG],
};
},
});

const response = await streamGraph(requestArgs);
const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'content',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {},
tags: [AGENT_NODE_TAG],
};
},
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'content',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
);
});
});
it('on_llm_end does not call handleStreamEnd if generations is undefined', async () => {
mockStreamEvents.mockReturnValue({
async *[Symbol.asyncIterator]() {
yield {
event: 'on_llm_stream',
data: { chunk: { message: { content: 'content' } } },
tags: [AGENT_NODE_TAG],
};
yield {
event: 'on_llm_end',
data: {},
tags: [AGENT_NODE_TAG],
};
},
});

const response = await streamGraph(requestArgs);
const response = await streamGraph(requestArgs);

expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).not.toHaveBeenCalled();
expect(response).toBe(mockResponseWithHeaders);
await waitFor(() => {
expect(mockPush).toHaveBeenCalledWith({ payload: 'content', type: 'content' });
expect(mockOnLlmResponse).not.toHaveBeenCalled();
});
});
});
});
Expand Down Expand Up @@ -337,5 +340,16 @@ describe('streamGraph', () => {
});
await expectConditions(response, true);
});
it('should execute the graph in streaming mode - OpenAI + inferenceChatModelDisabled = false', async () => {
const mockAssistantGraphAsyncIterator = {
streamEvents: () => mockAsyncIterator,
} as unknown as DefaultAssistantGraph;
const response = await streamGraph({
...requestArgs,
inferenceChatModelDisabled: false,
assistantGraph: mockAssistantGraphAsyncIterator,
});
await expectConditions(response);
});
});
});
Loading