Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion x-pack/plugins/actions/server/lib/gen_ai_token_tracking.ts
Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,7 @@ export const getGenAiTokenTracking = async ({
};

export const shouldTrackGenAiToken = (actionTypeId: string) =>
actionTypeId === '.gen-ai' || actionTypeId === '.bedrock' || actionTypeId === '.gemini';
actionTypeId === '.gen-ai' ||
actionTypeId === '.bedrock' ||
actionTypeId === '.gemini' ||
actionTypeId === '.inference';
3 changes: 3 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ export enum ServiceProviderKeys {

export const INFERENCE_CONNECTOR_ID = '.inference';
export enum SUB_ACTION {
UNIFIED_COMPLETION_ASYNC_ITERATOR = 'unified_completion_async_iterator',
UNIFIED_COMPLETION_STREAM = 'unified_completion_stream',
UNIFIED_COMPLETION = 'unified_completion',
COMPLETION = 'completion',
RERANK = 'rerank',
TEXT_EMBEDDING = 'text_embedding',
Expand Down
179 changes: 179 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,176 @@ export const ChatCompleteParamsSchema = schema.object({
input: schema.string(),
});

// subset of OpenAI.ChatCompletionMessageParam https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
const AIMessage = schema.object({
role: schema.string(),
content: schema.maybe(schema.string()),
name: schema.maybe(schema.string()),
tool_calls: schema.maybe(
schema.arrayOf(
schema.object({
id: schema.string(),
function: schema.object({
arguments: schema.maybe(schema.string()),
name: schema.maybe(schema.string()),
}),
type: schema.string(),
})
)
),
tool_call_id: schema.maybe(schema.string()),
});

const AITool = schema.object({
type: schema.string(),
function: schema.object({
name: schema.string(),
description: schema.maybe(schema.string()),
parameters: schema.maybe(schema.recordOf(schema.string(), schema.any())),
}),
});

// subset of OpenAI.ChatCompletionCreateParamsBase https://github.com/openai/openai-node/blob/master/src/resources/chat/completions.ts
export const UnifiedChatCompleteParamsSchema = schema.object({
body: schema.object({
messages: schema.arrayOf(AIMessage, { defaultValue: [] }),
model: schema.maybe(schema.string()),
/**
* The maximum number of [tokens](/tokenizer) that can be generated in the chat
* completion. This value can be used to control
* [costs](https://openai.com/api/pricing/) for text generated via API.
*
* This value is now deprecated in favor of `max_completion_tokens`, and is not
* compatible with
* [o1 series models](https://platform.openai.com/docs/guides/reasoning).
*/
max_tokens: schema.maybe(schema.number()),
/**
* Developer-defined tags and values used for filtering completions in the
* [dashboard](https://platform.openai.com/chat-completions).
*/
metadata: schema.maybe(schema.recordOf(schema.string(), schema.string())),
/**
* How many chat completion choices to generate for each input message. Note that
* you will be charged based on the number of generated tokens across all of the
* choices. Keep `n` as `1` to minimize costs.
*/
n: schema.maybe(schema.number()),
/**
* Up to 4 sequences where the API will stop generating further tokens.
*/
stop: schema.maybe(
schema.nullable(schema.oneOf([schema.string(), schema.arrayOf(schema.string())]))
),
/**
* What sampling temperature to use, between 0 and 2. Higher values like 0.8 will
* make the output more random, while lower values like 0.2 will make it more
* focused and deterministic.
*
* We generally recommend altering this or `top_p` but not both.
*/
temperature: schema.maybe(schema.number()),
/**
* Controls which (if any) tool is called by the model. `none` means the model will
* not call any tool and instead generates a message. `auto` means the model can
* pick between generating a message or calling one or more tools. `required` means
* the model must call one or more tools. Specifying a particular tool via
* `{"type": "function", "function": {"name": "my_function"}}` forces the model to
* call that tool.
*
* `none` is the default when no tools are present. `auto` is the default if tools
* are present.
*/
tool_choice: schema.maybe(
schema.oneOf([
schema.string(),
schema.object({
type: schema.string(),
function: schema.object({
name: schema.string(),
}),
}),
])
),
/**
* A list of tools the model may call. Currently, only functions are supported as a
* tool. Use this to provide a list of functions the model may generate JSON inputs
* for. A max of 128 functions are supported.
*/
tools: schema.maybe(schema.arrayOf(AITool)),
/**
* An alternative to sampling with temperature, called nucleus sampling, where the
* model considers the results of the tokens with top_p probability mass. So 0.1
* means only the tokens comprising the top 10% probability mass are considered.
*
* We generally recommend altering this or `temperature` but not both.
*/
top_p: schema.maybe(schema.number()),
/**
* A unique identifier representing your end-user, which can help OpenAI to monitor
* and detect abuse.
* [Learn more](https://platform.openai.com/docs/guides/safety-best-practices/end-user-ids).
*/
user: schema.maybe(schema.string()),
}),
// abort signal from client
signal: schema.maybe(schema.any()),
});

export const UnifiedChatCompleteResponseSchema = schema.object({
id: schema.string(),
choices: schema.arrayOf(
schema.object({
finish_reason: schema.maybe(
schema.nullable(
schema.oneOf([
schema.literal('stop'),
schema.literal('length'),
schema.literal('tool_calls'),
schema.literal('content_filter'),
schema.literal('function_call'),
])
)
),
index: schema.maybe(schema.number()),
message: schema.object({
content: schema.maybe(schema.nullable(schema.string())),
refusal: schema.maybe(schema.nullable(schema.string())),
role: schema.maybe(schema.string()),
tool_calls: schema.maybe(
schema.arrayOf(
schema.object({
id: schema.maybe(schema.string()),
index: schema.maybe(schema.number()),
function: schema.maybe(
schema.object({
arguments: schema.maybe(schema.string()),
name: schema.maybe(schema.string()),
})
),
type: schema.maybe(schema.string()),
}),
{ defaultValue: [] }
)
),
}),
}),
{ defaultValue: [] }
),
created: schema.maybe(schema.number()),
model: schema.maybe(schema.string()),
object: schema.maybe(schema.string()),
usage: schema.maybe(
schema.nullable(
schema.object({
completion_tokens: schema.maybe(schema.number()),
prompt_tokens: schema.maybe(schema.number()),
total_tokens: schema.maybe(schema.number()),
})
)
),
});

export const ChatCompleteResponseSchema = schema.arrayOf(
schema.object({
result: schema.string(),
Expand Down Expand Up @@ -66,3 +236,12 @@ export const TextEmbeddingResponseSchema = schema.arrayOf(
);

export const StreamingResponseSchema = schema.stream();

// Run action schema
export const DashboardActionParamsSchema = schema.object({
dashboardId: schema.string(),
});

export const DashboardActionResponseSchema = schema.object({
available: schema.boolean(),
});
10 changes: 10 additions & 0 deletions x-pack/plugins/stack_connectors/common/inference/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ import {
SparseEmbeddingResponseSchema,
TextEmbeddingParamsSchema,
TextEmbeddingResponseSchema,
UnifiedChatCompleteParamsSchema,
UnifiedChatCompleteResponseSchema,
DashboardActionParamsSchema,
DashboardActionResponseSchema,
} from './schema';
import { ConfigProperties } from '../dynamic_config/types';

export type Config = TypeOf<typeof ConfigSchema>;
export type Secrets = TypeOf<typeof SecretsSchema>;

export type UnifiedChatCompleteParams = TypeOf<typeof UnifiedChatCompleteParamsSchema>;
export type UnifiedChatCompleteResponse = TypeOf<typeof UnifiedChatCompleteResponseSchema>;

export type ChatCompleteParams = TypeOf<typeof ChatCompleteParamsSchema>;
export type ChatCompleteResponse = TypeOf<typeof ChatCompleteResponseSchema>;

Expand All @@ -38,6 +45,9 @@ export type TextEmbeddingResponse = TypeOf<typeof TextEmbeddingResponseSchema>;

export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;

export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;

export type FieldsConfiguration = Record<string, ConfigProperties>;

export interface InferenceProvider {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,27 @@ export const DEFAULT_TEXT_EMBEDDING_BODY = {
inputType: 'ingest',
};

export const DEFAULT_UNIFIED_CHAT_COMPLETE_BODY = {
body: {
messages: [
{
role: 'user',
content: 'Hello world',
},
],
},
};

export const DEFAULTS_BY_TASK_TYPE: Record<string, unknown> = {
[SUB_ACTION.COMPLETION]: DEFAULT_CHAT_COMPLETE_BODY,
[SUB_ACTION.UNIFIED_COMPLETION]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
[SUB_ACTION.UNIFIED_COMPLETION_STREAM]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
[SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR]: DEFAULT_UNIFIED_CHAT_COMPLETE_BODY,
[SUB_ACTION.RERANK]: DEFAULT_RERANK_BODY,
[SUB_ACTION.SPARSE_EMBEDDING]: DEFAULT_SPARSE_EMBEDDING_BODY,
[SUB_ACTION.TEXT_EMBEDDING]: DEFAULT_TEXT_EMBEDDING_BODY,
};

export const DEFAULT_TASK_TYPE = 'completion';
export const DEFAULT_TASK_TYPE = 'unified_completion';

export const DEFAULT_PROVIDER = 'elasticsearch';
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,16 @@ describe('OpenAI action params validation', () => {
subActionParams: { input: ['message test'], query: 'foobar' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_STREAM,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.UNIFIED_COMPLETION_ASYNC_ITERATOR,
subActionParams: { body: { messages: [{ role: 'user', content: 'What is Elastic?' }] } },
},
{
subAction: SUB_ACTION.TEXT_EMBEDDING,
Expand All @@ -53,6 +61,10 @@ describe('OpenAI action params validation', () => {
subAction: SUB_ACTION.SPARSE_EMBEDDING,
subActionParams: { input: 'message test' },
},
{
subAction: SUB_ACTION.COMPLETION,
subActionParams: { input: 'message test' },
},
])(
'validation succeeds when params are valid for subAction $subAction',
async ({ subAction, subActionParams }) => {
Expand All @@ -61,19 +73,25 @@ describe('OpenAI action params validation', () => {
subActionParams,
};
expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: [], subAction: [], inputType: [], query: [] },
errors: { body: [], input: [], subAction: [], inputType: [], query: [] },
});
}
);

test('params validation fails when params is a wrong object', async () => {
const actionParams = {
subAction: SUB_ACTION.COMPLETION,
subAction: SUB_ACTION.UNIFIED_COMPLETION,
subActionParams: { body: 'message {test}' },
};

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: { input: ['Input is required.'], inputType: [], query: [], subAction: [] },
errors: {
body: ['Messages is required.'],
inputType: [],
query: [],
subAction: [],
input: [],
},
});
});

Expand All @@ -84,6 +102,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -100,6 +119,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: [],
query: [],
Expand All @@ -116,6 +136,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: ['Input is required.', 'Input does not have a valid Array format.'],
inputType: [],
query: ['Query is required.'],
Expand All @@ -132,6 +153,7 @@ describe('OpenAI action params validation', () => {

expect(await actionTypeModel.validateParams(actionParams)).toEqual({
errors: {
body: [],
input: [],
inputType: ['Input type is required.'],
query: [],
Expand Down
Loading