Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions oas_docs/output/kibana.serverless.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -58784,6 +58784,9 @@ components:
reader:
$ref: '#/components/schemas/Security_AI_Assistant_API_Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/Security_AI_Assistant_API_MessageRole'
description: Message role.
Expand Down
3 changes: 3 additions & 0 deletions oas_docs/output/kibana.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68212,6 +68212,9 @@ components:
reader:
$ref: '#/components/schemas/Security_AI_Assistant_API_Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/Security_AI_Assistant_API_MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ export interface ChatCompleteResponse<
* The text content of the LLM response.
*/
content: string;
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string;
/**
* The eventual tool calls performed by the LLM.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions = ToolOp
* The eventual tool calls performed by the LLM.
*/
toolCalls: ToolCallsOf<TToolOptions>['toolCalls'];
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string;
/**
* Optional deanonymized input messages metadata
*/
Expand Down Expand Up @@ -83,6 +87,10 @@ export type ChatCompletionChunkEvent = InferenceTaskEventBase<
* The content chunk
*/
content: string;
/**
* Optional refusal reason chunk.
*/
refusal?: string;
/**
* The tool call chunks
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@ export type AssistantMessage = MessageBase<MessageRole.Assistant> & {
 * Note that LLMs with parallel tool invocation can potentially call multiple tools at the same time.
*/
toolCalls?: ToolCall[];
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string | null;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ export const completionChunkToLangchain = (chunk: ChatCompletionChunkEvent): AIM
};
});

const additionalKwargs = chunk.refusal ? { refusal: chunk.refusal } : {};

return new AIMessageChunk({
content: chunk.content,
tool_call_chunks: toolCallChunks,
additional_kwargs: {},
additional_kwargs: additionalKwargs,
response_metadata: {},
});
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import type { ChatCompleteResponse } from '@kbn/inference-common';
import { AIMessage } from '@langchain/core/messages';

export const responseToLangchainMessage = (response: ChatCompleteResponse): AIMessage => {
const additionalKwargs = response.refusal ? { refusal: response.refusal } : undefined;
return new AIMessage({
content: response.content,
...(additionalKwargs ? { additional_kwargs: additionalKwargs } : {}),
tool_calls: response.toolCalls.map((toolCall) => {
return {
id: toolCall.toolCallId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2739,6 +2739,9 @@ components:
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2739,6 +2739,9 @@ components:
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,10 @@ export const Message = z.object({
* Message content.
*/
content: z.string(),
/**
* Refusal reason returned by the model when content is filtered.
*/
refusal: z.string().optional(),
/**
* Message content.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,9 @@ components:
type: string
description: Message content.
example: 'Hello, how can I assist you today?'
refusal:
type: string
description: Refusal reason returned by the model when content is filtered.
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompleti
return {
type: ChatCompletionEventType.ChatCompletionChunk,
content: delta.content ?? '',
refusal: delta.refusal ?? undefined,
tool_calls:
delta.tool_calls?.map((toolCall) => {
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function chunksIntoMessage<TToolOptions extends ToolOptions>({

logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`);

const { content, tool_calls: toolCalls } = concatenatedChunk;
const { content, refusal, tool_calls: toolCalls } = concatenatedChunk;
const activeSpan = trace.getActiveSpan();
if (activeSpan) {
setChoice(activeSpan, { content, toolCalls });
Expand All @@ -56,6 +56,7 @@ export function chunksIntoMessage<TToolOptions extends ToolOptions>({
return {
type: ChatCompletionEventType.ChatCompletionMessage,
content,
...(refusal ? { refusal } : {}),
toolCalls: validatedToolCalls,
};
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { ChatCompletionChunkEvent, UnvalidatedToolCall } from '@kbn/inference-co

interface UnvalidatedMessage {
content: string;
refusal?: string;
tool_calls: UnvalidatedToolCall[];
}

Expand All @@ -19,6 +20,9 @@ export const mergeChunks = (chunks: ChatCompletionChunkEvent[]): UnvalidatedMess
const message = chunks.reduce<UnvalidatedMessage>(
(prev, chunk) => {
prev.content += chunk.content ?? '';
if (chunk.refusal) {
prev.refusal = chunk.refusal;
}

chunk.tool_calls?.forEach((toolCall) => {
let prevToolCall = prev.tool_calls[toolCall.index];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export const streamToResponse = <TToolOptions extends ToolOptions = ToolOptions>

return {
content: messageEvent.content,
refusal: messageEvent.refusal,
toolCalls: messageEvent.toolCalls,
tokens: tokenEvent?.tokens,
deanonymized_input: messageEvent.deanonymized_input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ export function chunksIntoMessage(obs$: Observable<UnifiedChatCompleteResponse>)
(prev, chunk) => {
if (chunk.choices.length > 0 && !chunk.usage) {
prev.choices[0].message.content += chunk.choices[0].message.content ?? '';
if (chunk.choices[0].message.refusal) {
prev.choices[0].message.refusal = chunk.choices[0].message.refusal;
}

chunk.choices[0].message.tool_calls?.forEach((toolCall) => {
if (toolCall.index !== undefined) {
Expand Down Expand Up @@ -89,6 +92,7 @@ export function chunksIntoMessage(obs$: Observable<UnifiedChatCompleteResponse>)
{
message: {
content: '',
refusal: null,
role: 'assistant',
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,29 @@ describe('appendConversationMessages', () => {
})
);
});

it('preserves refusal reason when present on messages', async () => {
const messageWithRefusal = createMockMessage({
refusal: 'Detected harmful input content: INSULTS',
});
setupSuccessfulTest();

await callAppendConversationMessages([messageWithRefusal]);

expect(dataWriter.bulk).toHaveBeenCalledWith(
expect.objectContaining({
documentsToUpdate: expect.arrayContaining([
expect.objectContaining({
messages: expect.arrayContaining([
expect.objectContaining({
refusal: 'Detected harmful input content: INSULTS',
}),
]),
}),
]),
})
);
});
});

describe('transformToUpdateScheme', () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ export const transformToUpdateScheme = (
messages: messages?.map((message) => ({
'@timestamp': message.timestamp,
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ export const transformToCreateScheme = (
messages: messages?.map((message) => ({
'@timestamp': message.timestamp,
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ export const conversationsFieldMap: FieldMap = {
array: false,
required: false,
},
'messages.refusal': {
type: 'text',
array: false,
required: false,
},
'messages.reader': {
type: 'object',
array: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ export const transformESToConversation = (
messageContent: message.content,
replacements,
}),
...(message.refusal ? { refusal: message.refusal } : {}),
...(message.is_error ? { isError: message.is_error } : {}),
...(message.reader ? { reader: message.reader } : {}),
...(message.user ? { user: message.user } : {}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ export interface EsConversationSchema {
messages?: Array<{
'@timestamp': string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down Expand Up @@ -70,6 +71,7 @@ export interface CreateMessageSchema {
messages?: Array<{
'@timestamp': string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ export interface UpdateConversationSchema {
messages?: Array<{
'@timestamp': string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down Expand Up @@ -133,6 +134,7 @@ export const transformToUpdateScheme = (
messages: messages.map((message) => ({
'@timestamp': message.timestamp,
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ import { AIAssistantDataClient } from '../../../ai_assistant_data_clients';
export type OnLlmResponse = (
content: string,
traceData?: Message['traceData'],
isError?: boolean
isError?: boolean,
refusal?: string
) => Promise<void>;

export interface AssistantDataClients {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ describe('streamGraph', () => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'final message',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
false,
undefined
);
});
});
Expand Down Expand Up @@ -177,7 +178,8 @@ describe('streamGraph', () => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'content',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
false,
undefined
);
});
});
Expand Down Expand Up @@ -239,7 +241,8 @@ describe('streamGraph', () => {
expect(mockOnLlmResponse).toHaveBeenCalledWith(
'Look at these rare IP addresses.',
{ transactionId: 'transactionId', traceId: 'traceId' },
false
false,
undefined
);
});
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ export const streamGraph = async ({
} = streamFactory<{ type: string; payload: string }>(request.headers, logger, false, false);

let didEnd = false;
const handleStreamEnd = (finalResponse: string, isError = false) => {
const handleStreamEnd = (finalResponse: string, isError = false, refusal?: string) => {
if (didEnd) {
return;
}
Expand All @@ -92,7 +92,8 @@ export const streamGraph = async ({
transactionId: streamingSpan?.transaction?.ids?.['transaction.id'],
traceId: streamingSpan?.ids?.['trace.id'],
},
isError
isError,
refusal
).catch(() => {});
}
streamEnd();
Expand Down Expand Up @@ -129,7 +130,11 @@ export const streamGraph = async ({
!data.output.lc_kwargs?.tool_calls?.length &&
!didEnd
) {
handleStreamEnd(data.output.content);
const refusal =
typeof data.output?.additional_kwargs?.refusal === 'string'
? (data.output.additional_kwargs.refusal as string)
: undefined;
handleStreamEnd(data.output.content, false, refusal);
} else if (
// This is the end of one model invocation but more messages will follow as there are tool calls. If this chunk contains text content, add a newline separator to the stream to visually separate the chunks.
event === 'on_chat_model_end' &&
Expand Down Expand Up @@ -206,8 +211,12 @@ export const invokeGraph = async ({
const lastMessage = result.messages[result.messages.length - 1];
const output = lastMessage.text;
const conversationId = result.conversationId;
const refusal =
typeof lastMessage?.additional_kwargs?.refusal === 'string'
? (lastMessage.additional_kwargs.refusal as string)
: undefined;
if (onLlmResponse) {
await onLlmResponse(output, traceData);
await onLlmResponse(output, traceData, false, refusal);
}

return { output, traceData, conversationId };
Expand Down
Loading