Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import { LLM } from '@langchain/core/language_models/llms';
import { get } from 'lodash/fp';
import { v4 as uuidv4 } from 'uuid';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { DEFAULT_TIMEOUT, getDefaultArguments } from './constants';

import { getMessageContentAndRole } from './helpers';
Expand All @@ -28,6 +29,7 @@ interface ActionsClientLlmParams {
timeout?: number;
traceId?: string;
traceOptions?: TraceOptions;
telemetryMetadata?: TelemetryMetadata;
}

export class ActionsClientLlm extends LLM {
Expand All @@ -36,6 +38,7 @@ export class ActionsClientLlm extends LLM {
#logger: Logger;
#traceId: string;
#timeout?: number;
telemetryMetadata?: TelemetryMetadata;

// Local `llmType` as it can change and needs to be accessed by abstract `_llmType()` method
// Not using getter as `this._llmType()` is called in the constructor via `super({})`
Expand All @@ -54,6 +57,7 @@ export class ActionsClientLlm extends LLM {
temperature,
timeout,
traceOptions,
telemetryMetadata,
}: ActionsClientLlmParams) {
super({
callbacks: [...(traceOptions?.tracers ?? [])],
Expand All @@ -67,6 +71,7 @@ export class ActionsClientLlm extends LLM {
this.#timeout = timeout;
this.model = model;
this.temperature = temperature;
this.telemetryMetadata = telemetryMetadata;
}

_llmType() {
Expand Down Expand Up @@ -102,6 +107,7 @@ export class ActionsClientLlm extends LLM {
model: this.model,
messages: [assistantMessage], // the assistant message
},
telemetryMetadata: this.telemetryMetadata,
},
}
: {
Expand All @@ -113,6 +119,7 @@ export class ActionsClientLlm extends LLM {
...getDefaultArguments(this.llmType, this.temperature),
// This timeout is large because LangChain prompts can be complicated and take a long time
timeout: this.#timeout ?? DEFAULT_TIMEOUT,
telemetryMetadata: this.telemetryMetadata,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ export function registerAnalyzeLogsRoutes(
maxTokens: 4096,
signal: abortSignal,
streaming: false,
telemetryMetadata: {
pluginId: 'automatic_import',
},
});
const options = {
callbacks: [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ export function registerCategorizationRoutes(
maxTokens: 4096,
signal: abortSignal,
streaming: false,
telemetryMetadata: {
pluginId: 'automatic_import',
},
});

const parameters = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ export function registerCelInputRoutes(router: IRouter<IntegrationAssistantRoute
maxTokens: 4096,
signal: abortSignal,
streaming: false,
telemetryMetadata: {
pluginId: 'automatic_import',
},
});

const parameters = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ export function registerEcsRoutes(router: IRouter<IntegrationAssistantRouteHandl
maxTokens: 4096,
signal: abortSignal,
streaming: false,
telemetryMetadata: {
pluginId: 'automatic_import',
},
});

const parameters = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ export function registerRelatedRoutes(router: IRouter<IntegrationAssistantRouteH
maxTokens: 4096,
signal: abortSignal,
streaming: false,
telemetryMetadata: {
pluginId: 'automatic_import',
},
});

const parameters = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ describe('InferenceConnector', () => {

const response = await connector.performApiUnifiedCompletion({
body: { messages: [{ content: 'What is Elastic?', role: 'user' }] },
telemetryMetadata: { pluginId: 'security_ai_assistant' },
});
expect(mockEsClient.transport.request).toBeCalledTimes(1);
expect(mockEsClient.transport.request).toHaveBeenCalledWith(
Expand All @@ -86,7 +87,13 @@ describe('InferenceConnector', () => {
method: 'POST',
path: '_inference/chat_completion/test/_stream',
},
{ asStream: true, meta: true }
{
asStream: true,
meta: true,
headers: {
'X-Elastic-Product-Use-Case': 'security_ai_assistant',
},
}
);
expect(response.choices[0].message.content).toEqual(' you');
});
Expand Down Expand Up @@ -292,7 +299,10 @@ describe('InferenceConnector', () => {
method: 'POST',
path: '_inference/chat_completion/test/_stream',
},
{ asStream: true, meta: true }
{
asStream: true,
meta: true,
}
);
});

Expand All @@ -314,7 +324,11 @@ describe('InferenceConnector', () => {
method: 'POST',
path: '_inference/chat_completion/test/_stream',
},
{ asStream: true, meta: true, signal }
{
asStream: true,
meta: true,
signal,
}
);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,13 @@ export class InferenceConnector extends SubActionConnector<Config, Secrets> {
asStream: true,
meta: true,
signal: params.signal,
...(params.telemetryMetadata?.pluginId
? {
headers: {
'X-Elastic-Product-Use-Case': params.telemetryMetadata?.pluginId,
},
}
: {}),
}
);
// errors should be thrown as it will not be a stream response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,8 @@ export const getEvaluatorLlm = async ({
temperature: 0, // zero temperature for evaluation
timeout: connectorTimeout,
traceOptions,
telemetryMetadata: {
pluginId: 'security_attack_discovery',
},
});
};
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ export const evaluateAttackDiscovery = async ({
temperature: 0, // zero temperature for attack discovery, because we want structured JSON output
timeout: connectorTimeout,
traceOptions,
telemetryMetadata: {
pluginId: 'security_attack_discovery',
},
});

const graph = getDefaultAttackDiscoveryGraph({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ export const callAssistantGraph: AgentExecutor<true | false> = async ({
// failure could be due to bad connector, we should deliver that result to the client asap
maxRetries: 0,
convertSystemMessageToHumanContent: false,
telemetryMetadata: {
pluginId: 'security_ai_assistant',
},
});

const anonymizationFieldsRes =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ export const invokeAttackDiscoveryGraph = async ({
temperature: 0, // zero temperature for attack discovery, because we want structured JSON output
timeout: connectorTimeout,
traceOptions,
telemetryMetadata: {
pluginId: 'security_attack_discovery',
},
});

if (llm == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,9 @@ export function getAssistantToolParams({
temperature: 0, // zero temperature because we want structured JSON output
timeout: connectorTimeout,
traceOptions,
telemetryMetadata: {
pluginId: 'security_defend_insights',
},
});

return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ export const postEvaluateRoute = (
streaming: false,
maxRetries: 0,
convertSystemMessageToHumanContent: false,
telemetryMetadata: {
pluginId: 'security_ai_assistant',
},
});
const llm = createLlmInstance();
const anonymizationFieldsRes =
Expand Down