Merged · Changes from all commits (19 commits)
@@ -56,6 +56,8 @@ export {
isToolValidationError,
isTokenLimitReachedError,
isToolNotFoundError,
type ChatCompleteMetadata,
type ConnectorTelemetryMetadata,
} from './src/chat_complete';
export {
OutputEventType,
@@ -9,6 +9,7 @@ import type { Observable } from 'rxjs';
import type { ToolCallsOf, ToolOptions } from './tools';
import type { Message } from './messages';
import type { ChatCompletionEvent, ChatCompletionTokenCount } from './events';
import type { ChatCompleteMetadata } from './metadata';

/**
* Request a completion from the LLM based on a prompt or conversation.
@@ -109,6 +110,10 @@ export type ChatCompleteOptions<
* Optional signal that can be used to forcefully abort the request.
*/
abortSignal?: AbortSignal;
/**
* Optional metadata related to call execution.
*/
metadata?: ChatCompleteMetadata;
} & TToolOptions;

/**
@@ -50,6 +50,7 @@ export {
type UnvalidatedToolCall,
type ToolChoice,
} from './tools';
export type { ChatCompleteMetadata, ConnectorTelemetryMetadata } from './metadata';
export {
isChatCompletionChunkEvent,
isChatCompletionEvent,
@@ -0,0 +1,23 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

/**
 * Set of metadata that can be used when calling the inference APIs
*
* @public
*/
export interface ChatCompleteMetadata {
connectorTelemetry?: ConnectorTelemetryMetadata;
}

/**
 * Pass-through metadata for connector telemetry
*/
export interface ConnectorTelemetryMetadata {
pluginId?: string;
aggregateBy?: string;
}
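As a usage sketch (not part of this diff): a consuming plugin can attach this metadata when calling `chatComplete`. This assumes the types are exported from `@kbn/inference-common`; the client handle, connector id, and metadata values below are placeholders.

```ts
import { MessageRole, type ChatCompleteMetadata } from '@kbn/inference-common';

// Placeholder values: pluginId identifies the calling plugin, aggregateBy groups usage events.
const metadata: ChatCompleteMetadata = {
  connectorTelemetry: {
    pluginId: 'my_plugin',
    aggregateBy: 'conversation-1',
  },
};

// `chatComplete` stands in for the inference client's method of the same name
// (obtained from the inference plugin's start contract; not shown in this PR).
async function askWithTelemetry(chatComplete: (opts: Record<string, unknown>) => Promise<unknown>) {
  return chatComplete({
    connectorId: 'my-genai-connector', // placeholder connector id
    messages: [{ role: MessageRole.User, content: 'Summarize the latest alerts' }],
    metadata,
  });
}
```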
@@ -6,7 +6,13 @@
*/

import type { Observable } from 'rxjs';
import { Message, FunctionCallingMode, FromToolSchema, ToolSchema } from '../chat_complete';
import {
Message,
FunctionCallingMode,
FromToolSchema,
ToolSchema,
ChatCompleteMetadata,
} from '../chat_complete';
import { Output, OutputEvent } from './events';

/**
@@ -117,6 +123,10 @@ export interface OutputOptions<
*/
onValidationError?: boolean | number;
};
/**
* Optional metadata related to call execution.
*/
metadata?: ChatCompleteMetadata;
}

/**
@@ -10,6 +10,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { Logger } from '@kbn/logging';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { prepareMessages, DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION } from '../utils/bedrock';

export interface CustomChatModelInput extends BaseChatModelParams {
@@ -20,6 +21,7 @@ export interface CustomChatModelInput extends BaseChatModelParams {
signal?: AbortSignal;
model?: string;
maxTokens?: number;
telemetryMetadata?: TelemetryMetadata;
}

/**
@@ -49,6 +51,10 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
params: {
subAction: 'invokeAIRaw',
subActionParams: {
telemetryMetadata: {
pluginId: params?.telemetryMetadata?.pluginId,
aggregateBy: params?.telemetryMetadata?.aggregateBy,
},
messages: prepareMessages(inputBody.messages),
temperature: params.temperature ?? inputBody.temperature,
stopSequences: inputBody.stop_sequences,
@@ -13,30 +13,33 @@ import {
ConverseStreamCommand,
ConverseStreamResponse,
} from '@aws-sdk/client-bedrock-runtime';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { constructStack } from '@smithy/middleware-stack';
import { HttpHandlerOptions } from '@smithy/types';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { ActionsClient } from '@kbn/actions-plugin/server';

import { prepareMessages } from '../../utils/bedrock';

export interface CustomChatModelInput extends BedrockRuntimeClientConfig {
actionsClient: PublicMethodsOf<ActionsClient>;
connectorId: string;
streaming?: boolean;
telemetryMetadata?: TelemetryMetadata;
}

export class BedrockRuntimeClient extends _BedrockRuntimeClient {
middlewareStack: _BedrockRuntimeClient['middlewareStack'];
streaming: boolean;
actionsClient: PublicMethodsOf<ActionsClient>;
connectorId: string;
telemetryMetadata?: TelemetryMetadata;

constructor({ actionsClient, connectorId, ...fields }: CustomChatModelInput) {
super(fields ?? {});
this.streaming = fields.streaming ?? true;
this.actionsClient = actionsClient;
this.connectorId = connectorId;
this.telemetryMetadata = fields?.telemetryMetadata;
// eliminate middleware steps that handle auth as Kibana connector handles auth
this.middlewareStack = constructStack() as _BedrockRuntimeClient['middlewareStack'];
}
@@ -56,6 +59,7 @@ export class BedrockRuntimeClient extends _BedrockRuntimeClient {
params: {
subAction: 'bedrockClientSend',
subActionParams: {
telemetryMetadata: this.telemetryMetadata,
command,
signal: options?.abortSignal,
},
@@ -9,6 +9,7 @@ import type { ActionsClient } from '@kbn/actions-plugin/server';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { Logger } from '@kbn/logging';
import { PublicMethodsOf } from '@kbn/utility-types';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { BedrockRuntimeClient } from './bedrock_runtime_client';
import { DEFAULT_BEDROCK_MODEL, DEFAULT_BEDROCK_REGION } from '../../utils/bedrock';

@@ -18,6 +19,7 @@ export interface CustomChatModelInput extends BaseChatModelParams {
logger: Logger;
signal?: AbortSignal;
model?: string;
telemetryMetadata?: TelemetryMetadata;
}

/**
@@ -45,6 +47,7 @@ export class ActionsClientChatBedrockConverse extends ChatBedrockConverse {
connectorId,
streaming: this.streaming,
region: DEFAULT_BEDROCK_REGION,
telemetryMetadata: fields?.telemetryMetadata,
});
}
}
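For the LangChain wrappers, a minimal wiring sketch (the import path, connector id, and metadata values are assumptions, not taken from this diff): `telemetryMetadata` passed at construction is handed to `BedrockRuntimeClient` and forwarded on the `bedrockClientSend` sub-action params shown above.

```ts
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { PublicMethodsOf } from '@kbn/utility-types';
import type { Logger } from '@kbn/logging';
// Import path assumed; adjust to wherever ActionsClientChatBedrockConverse is exported from.
import { ActionsClientChatBedrockConverse } from '@kbn/langchain/server';

const createBedrockModel = (actionsClient: PublicMethodsOf<ActionsClient>, logger: Logger) =>
  new ActionsClientChatBedrockConverse({
    actionsClient,
    connectorId: 'my-bedrock-connector', // placeholder connector id
    logger,
    telemetryMetadata: {
      pluginId: 'my_plugin',         // placeholder: plugin to attribute connector usage to
      aggregateBy: 'conversation-1', // placeholder: grouping key for usage events
    },
  });
```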
@@ -9,11 +9,12 @@ import { v4 as uuidv4 } from 'uuid';
import { Logger } from '@kbn/core/server';
import type { ActionsClient } from '@kbn/actions-plugin/server';
import { get } from 'lodash/fp';

import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { ChatOpenAI } from '@langchain/openai';
import { Stream } from 'openai/streaming';
import type OpenAI from 'openai';
import { PublicMethodsOf } from '@kbn/utility-types';

import { DEFAULT_OPEN_AI_MODEL, DEFAULT_TIMEOUT } from './constants';
import {
InferenceChatCompleteParamsSchema,
@@ -36,6 +37,7 @@ export interface ActionsClientChatOpenAIParams {
temperature?: number;
signal?: AbortSignal;
timeout?: number;
telemetryMetadata?: TelemetryMetadata;
}

/**
@@ -65,6 +67,7 @@ export class ActionsClientChatOpenAI extends ChatOpenAI {
#traceId: string;
#signal?: AbortSignal;
#timeout?: number;
telemetryMetadata?: TelemetryMetadata;

constructor({
actionsClient,
@@ -79,6 +82,7 @@
temperature,
timeout,
maxTokens,
telemetryMetadata,
}: ActionsClientChatOpenAIParams) {
super({
maxRetries,
@@ -109,6 +113,7 @@
// matters only for LangSmith logs (Metadata > Invocation Params)
// the connector can be passed an undefined temperature through #temperature
this.temperature = temperature ?? this.temperature;
this.telemetryMetadata = telemetryMetadata;
}

getActionResultData(): string {
@@ -237,6 +242,7 @@
: completionRequest.stream
? { ...body, timeout: this.#timeout ?? DEFAULT_TIMEOUT }
: { body: JSON.stringify(body), timeout: this.#timeout ?? DEFAULT_TIMEOUT }),
telemetryMetadata: this.telemetryMetadata,
signal: this.#signal,
};
return {
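The OpenAI wrapper follows the same pattern; a short sketch (identifiers and import path assumed, not from this PR) where the constructor-level `telemetryMetadata` is included on every completion request body built above:

```ts
import type { ActionsClient } from '@kbn/actions-plugin/server';
import type { PublicMethodsOf } from '@kbn/utility-types';
import type { Logger } from '@kbn/logging';
// Import path assumed for illustration.
import { ActionsClientChatOpenAI } from '@kbn/langchain/server';

const createOpenAIModel = (actionsClient: PublicMethodsOf<ActionsClient>, logger: Logger) =>
  new ActionsClientChatOpenAI({
    actionsClient,
    connectorId: 'my-openai-connector', // placeholder connector id
    logger,
    // Placeholder metadata, forwarded to the connector with each request.
    telemetryMetadata: { pluginId: 'my_plugin', aggregateBy: 'conversation-1' },
  });
```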
@@ -18,6 +18,7 @@ import { Logger } from '@kbn/logging';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { GeminiPartText } from '@langchain/google-common/dist/types';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import {
convertResponseBadFinishReasonToErrorMsg,
convertResponseContentToChatGenerationChunk,
@@ -34,12 +35,14 @@ export interface CustomChatModelInput extends BaseChatModelParams {
signal?: AbortSignal;
model?: string;
maxTokens?: number;
telemetryMetadata?: TelemetryMetadata;
}

export class ActionsClientChatVertexAI extends ChatVertexAI {
#actionsClient: PublicMethodsOf<ActionsClient>;
#connectorId: string;
#model?: string;
telemetryMetadata?: TelemetryMetadata;
constructor({ actionsClient, connectorId, ...props }: CustomChatModelInput) {
super({
...props,
@@ -62,7 +65,8 @@ export class ActionsClientChatVertexAI extends ChatVertexAI {
client,
false,
actionsClient,
connectorId
connectorId,
props?.telemetryMetadata
);
}

@@ -89,6 +93,7 @@
subAction: 'invokeStream',
subActionParams: {
model: this.#model,
telemetryMetadata: this.telemetryMetadata,
messages: data?.contents,
tools: data?.tools,
temperature: this.temperature,
@@ -16,6 +16,7 @@ import { ActionsClient } from '@kbn/actions-plugin/server';
import { PublicMethodsOf } from '@kbn/utility-types';
import { EnhancedGenerateContentResponse } from '@google/generative-ai';
import { AsyncCaller } from '@langchain/core/utils/async_caller';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import { convertResponseBadFinishReasonToErrorMsg } from '../../utils/gemini';

// only implements non-streaming requests
@@ -26,17 +27,20 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
#model?: string;
temperature: number;
caller: AsyncCaller;
telemetryMetadata?: TelemetryMetadata;
constructor(
fields: GoogleAIBaseLLMInput<Auth>,
caller: AsyncCaller,
client: GoogleAbstractedClient,
_streaming: boolean, // defaulting to false in the super
actionsClient: PublicMethodsOf<ActionsClient>,
connectorId: string
connectorId: string,
telemetryMetadata?: TelemetryMetadata
) {
super(fields, caller, client, false);
this.actionsClient = actionsClient;
this.connectorId = connectorId;
this.telemetryMetadata = telemetryMetadata;
this.caller = caller;
this.#model = fields.model;
this.temperature = fields.temperature ?? 0;
@@ -77,6 +81,7 @@
params: {
subAction: 'invokeAIRaw',
subActionParams: {
telemetryMetadata: this.telemetryMetadata,
model: this.#model,
messages: data?.contents,
tools: data?.tools,
@@ -21,6 +21,7 @@ import { Logger } from '@kbn/logging';
import { BaseChatModelParams } from '@langchain/core/language_models/chat_models';
import { get } from 'lodash/fp';
import { Readable } from 'stream';
import type { TelemetryMetadata } from '@kbn/actions-plugin/server/lib';
import {
convertBaseMessagesToContent,
convertResponseBadFinishReasonToErrorMsg,
@@ -36,20 +37,23 @@ export interface CustomChatModelInput extends BaseChatModelParams {
signal?: AbortSignal;
model?: string;
maxTokens?: number;
telemetryMetadata?: TelemetryMetadata;
}

export class ActionsClientGeminiChatModel extends ChatGoogleGenerativeAI {
#actionsClient: PublicMethodsOf<ActionsClient>;
#connectorId: string;
#temperature: number;
#model?: string;
telemetryMetadata?: TelemetryMetadata;

constructor({ actionsClient, connectorId, ...props }: CustomChatModelInput) {
super({
...props,
apiKey: 'asda',
maxOutputTokens: props.maxTokens ?? 2048,
});
this.telemetryMetadata = props.telemetryMetadata;
// LangChain needs model to be defined for logging purposes
this.model = props.model ?? this.model;
// If model is not specified by consumer, the connector will define it so do not pass
@@ -71,6 +75,7 @@
params: {
subAction: 'invokeAIRaw',
subActionParams: {
telemetryMetadata: this.telemetryMetadata,
model: this.#model,
messages: request.contents,
tools: request.tools,
@@ -159,6 +164,7 @@
}, []),
temperature: this.#temperature,
tools: request.tools,
telemetryMetadata: this.telemetryMetadata,
},
},
};
@@ -227,6 +227,11 @@ describe('ActionsClientSimpleChatModel', () => {
temperature: 0,
stopSequences: ['\n'],
maxTokens: 333,
model: undefined,
telemetryMetadata: {
aggregateBy: undefined,
pluginId: undefined,
},
});

expect(result).toEqual(mockActionResponse.message);
@@ -252,6 +257,11 @@

expect(rest).toEqual({
temperature: 0,
model: undefined,
telemetryMetadata: {
aggregateBy: undefined,
pluginId: undefined,
},
});

expect(result).toEqual(mockActionResponse.message);