Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions oas_docs/output/kibana.serverless.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83989,6 +83989,9 @@ components:
reader:
$ref: '#/components/schemas/Security_AI_Assistant_API_Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/Security_AI_Assistant_API_MessageRole'
description: Message role.
Expand Down
3 changes: 3 additions & 0 deletions oas_docs/output/kibana.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94507,6 +94507,9 @@ components:
reader:
$ref: '#/components/schemas/Security_AI_Assistant_API_Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/Security_AI_Assistant_API_MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ export interface ChatCompleteResponse<
* The text content of the LLM response.
*/
content: string;
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string;
/**
* The eventual tool calls performed by the LLM.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ export type ChatCompletionMessageEvent<TToolOptions extends ToolOptions = ToolOp
* The text content of the LLM response.
*/
content: string;
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string;
/**
* Optional deanonymized input messages metadata
*/
Expand Down Expand Up @@ -84,6 +88,10 @@ export type ChatCompletionChunkEvent = InferenceTaskEventBase<
* The content chunk
*/
content: string;
/**
* Optional refusal reason chunk.
*/
refusal?: string;
/**
* The tool call chunks
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ export type AssistantMessage<TToolCalls extends ToolCall[] | undefined = ToolCal
* Can be null if the LLM called a tool.
*/
content: string | null;
/**
* Optional refusal reason returned by the model when content is filtered.
*/
refusal?: string | null;
// make sure `toolCalls` inherits the optionality from `TToolCalls`
} & (TToolCalls extends ToolCall[]
? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ export const completionChunkToLangchain = (chunk: ChatCompletionChunkEvent): AIM
};
});

const additionalKwargs = chunk.refusal ? { refusal: chunk.refusal } : {};

return new AIMessageChunk({
content: chunk.content,
tool_call_chunks: toolCallChunks,
additional_kwargs: {},
additional_kwargs: additionalKwargs,
response_metadata: {},
});
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ import type { ChatCompleteResponse } from '@kbn/inference-common';
import { AIMessage } from '@langchain/core/messages';

export const responseToLangchainMessage = (response: ChatCompleteResponse): AIMessage => {
const additionalKwargs = response.refusal ? { refusal: response.refusal } : undefined;
return new AIMessage({
content: response.content,
...(additionalKwargs ? { additional_kwargs: additionalKwargs } : {}),
tool_calls: response.toolCalls.map((toolCall) => {
return {
id: toolCall.toolCallId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2925,6 +2925,9 @@ components:
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2925,6 +2925,9 @@ components:
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
refusal:
description: Refusal reason returned by the model when content is filtered.
type: string
role:
$ref: '#/components/schemas/MessageRole'
description: Message role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ export const Message = z.object({
* Message content.
*/
content: z.string(),
/**
* Refusal reason returned by the model when content is filtered.
*/
refusal: z.string().optional(),
/**
* Message content.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,9 @@ components:
type: string
description: Message content.
example: 'Hello, how can I assist you today?'
refusal:
type: string
description: Refusal reason returned by the model when content is filtered.
reader:
$ref: '#/components/schemas/Reader'
description: Message content.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export function chunkFromOpenAI(chunk: OpenAI.ChatCompletionChunk): ChatCompleti
return {
type: ChatCompletionEventType.ChatCompletionChunk,
content: delta.content ?? '',
refusal: delta.refusal ?? undefined,
tool_calls:
delta.tool_calls?.map((toolCall) => {
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export function chunksIntoMessage<TToolOptions extends ToolOptions>({

logger.debug(() => `Received completed message: ${JSON.stringify(concatenatedChunk)}`);

const { content, tool_calls: toolCalls } = concatenatedChunk;
const { content, refusal, tool_calls: toolCalls } = concatenatedChunk;
const activeSpan = trace.getActiveSpan();
if (activeSpan) {
setChoice(activeSpan, { content, toolCalls });
Expand All @@ -59,6 +59,7 @@ export function chunksIntoMessage<TToolOptions extends ToolOptions>({
return {
type: ChatCompletionEventType.ChatCompletionMessage,
content,
...(refusal ? { refusal } : {}),
toolCalls: validatedToolCalls,
};
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import type { ChatCompletionChunkEvent, UnvalidatedToolCall } from '@kbn/inferen

/**
 * Accumulator shape used while folding a list of chat-completion chunk events
 * into a single message, before the tool calls have been validated.
 */
interface UnvalidatedMessage {
  /** Text content built up from the chunks' `content` values. */
  content: string;
  /** Refusal reason taken from the chunks, when any chunk carried one. */
  refusal?: string;
  /** Tool calls gathered from the chunks; not yet validated. */
  tool_calls: UnvalidatedToolCall[];
}

Expand All @@ -19,6 +20,9 @@ export const mergeChunks = (chunks: ChatCompletionChunkEvent[]): UnvalidatedMess
const message = chunks.reduce<UnvalidatedMessage>(
(prev, chunk) => {
prev.content += chunk.content ?? '';
// Refusal text arrives as per-chunk fragments (it is populated from the
// provider's streaming delta), so append each fragment the same way `content`
// is appended instead of replacing what has been accumulated so far.
if (chunk.refusal) {
  prev.refusal = (prev.refusal ?? '') + chunk.refusal;
}

chunk.tool_calls?.forEach((toolCall) => {
let prevToolCall = prev.tool_calls[toolCall.index];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export const streamToResponse = <TToolOptions extends ToolOptions = ToolOptions>

return {
content: messageEvent.content,
refusal: messageEvent.refusal,
toolCalls: messageEvent.toolCalls,
tokens: tokenEvent?.tokens,
deanonymized_input: messageEvent.deanonymized_input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ export function chunksIntoMessage(obs$: Observable<UnifiedChatCompleteResponse>)
(prev, chunk) => {
if (chunk.choices.length > 0 && !chunk.usage) {
prev.choices[0].message.content += chunk.choices[0].message.content ?? '';
// Append refusal fragments rather than overwriting them, mirroring how
// `content` is concatenated above. The accumulator starts out as `null`,
// hence the coalescing to an empty string before appending.
if (chunk.choices[0].message.refusal) {
  prev.choices[0].message.refusal =
    (prev.choices[0].message.refusal ?? '') + chunk.choices[0].message.refusal;
}

chunk.choices[0].message.tool_calls?.forEach((toolCall) => {
if (toolCall.index !== undefined) {
Expand Down Expand Up @@ -89,6 +92,7 @@ export function chunksIntoMessage(obs$: Observable<UnifiedChatCompleteResponse>)
{
message: {
content: '',
refusal: null,
role: 'assistant',
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,29 @@ describe('appendConversationMessages', () => {
);
});

// A refusal set on an incoming message must show up unchanged on the message
// inside the document handed to the bulk writer.
it('preserves refusal reason when present on messages', async () => {
  const refusal = 'Detected harmful input content: INSULTS';
  const messageWithRefusal = createMockMessage({ refusal });
  setupSuccessfulTest();

  await callAppendConversationMessages([messageWithRefusal]);

  // Build the nested matcher from the inside out: message -> document -> bulk args.
  const matchingMessage = expect.objectContaining({ refusal });
  const matchingDocument = expect.objectContaining({
    messages: expect.arrayContaining([matchingMessage]),
  });
  const matchingBulkArgs = expect.objectContaining({
    documentsToUpdate: expect.arrayContaining([matchingDocument]),
  });

  expect(dataWriter.bulk).toHaveBeenCalledWith(matchingBulkArgs);
});

it('generates UUID for messages without id', async () => {
const messageWithoutId = createMockMessage({ id: undefined });
setupSuccessfulTest();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ export const transformToUpdateScheme = (
'@timestamp': message.timestamp,
id: message.id ?? uuidv4(),
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ export const transformToCreateScheme = (
'@timestamp': message.timestamp,
id: message.id ?? uuidv4(),
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ export const conversationsFieldMap: FieldMap = {
array: false,
required: false,
},
'messages.refusal': {
type: 'text',
array: false,
required: false,
},
'messages.reader': {
type: 'object',
array: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ export const transformESToConversation = (
messageContent: message.content,
replacements,
}),
...(message.refusal ? { refusal: message.refusal } : {}),
...(message.is_error ? { isError: message.is_error } : {}),
...(message.reader ? { reader: message.reader } : {}),
...(message.user ? { user: message.user } : {}),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export interface EsConversationSchema {
'@timestamp': string;
id?: string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down Expand Up @@ -77,6 +78,7 @@ export interface CreateMessageSchema {
'@timestamp': string;
id: string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ export interface UpdateConversationSchema {
'@timestamp': string;
id: string;
content: string;
refusal?: string;
reader?: Reader;
role: MessageRole;
is_error?: boolean;
Expand Down Expand Up @@ -136,6 +137,7 @@ export const transformToUpdateScheme = (
'@timestamp': message.timestamp,
id: message.id ?? uuidv4(),
content: message.content,
...(message.refusal ? { refusal: message.refusal } : {}),
is_error: message.isError,
reader: message.reader,
role: message.role,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import type { AIAssistantDataClient } from '../../../ai_assistant_data_clients';

export type OnLlmResponse = (args: {
content: string;
refusal?: string;
interruptValue?: InterruptValue;
traceData?: Message['traceData'];
isError?: boolean;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,14 @@ export const streamGraph = async ({

const handleFinalContent = (args: {
finalResponse: string;
refusal?: string;
isError: boolean;
interruptValue?: InterruptValue;
}) => {
if (onLlmResponse) {
onLlmResponse({
content: args.finalResponse,
refusal: args.refusal,
interruptValue: args.interruptValue,
traceData: {
transactionId: streamingSpan?.transaction?.ids?.['transaction.id'],
Expand Down Expand Up @@ -151,7 +153,11 @@ export const streamGraph = async ({
!data.output.lc_kwargs?.tool_calls?.length &&
!didEnd
) {
handleFinalContent({ finalResponse: data.output.content, isError: false });
const refusal =
typeof data.output?.additional_kwargs?.refusal === 'string'
? (data.output.additional_kwargs.refusal as string)
: undefined;
handleFinalContent({ finalResponse: data.output.content, refusal, isError: false });
} else if (
// This is the end of one model invocation, but more messages will follow because there are tool calls. If this chunk contains text content, add a newline separator to the stream to visually separate the chunks.
event === 'on_chat_model_end' &&
Expand Down Expand Up @@ -234,10 +240,15 @@ export const invokeGraph = async ({
const lastMessage = result.messages[result.messages.length - 1];
const output = lastMessage.text;
const conversationId = result.conversationId;
const refusal =
typeof lastMessage?.additional_kwargs?.refusal === 'string'
? (lastMessage.additional_kwargs.refusal as string)
: undefined;
if (onLlmResponse) {
await onLlmResponse({
content: output,
traceData,
...(refusal ? { refusal } : {}),
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ export const chatCompleteRoute = (

const onLlmResponse: OnLlmResponse = async ({
content,
refusal,
traceData = {},
isError = false,
}): Promise<void> => {
Expand All @@ -225,6 +226,7 @@ export const chatCompleteRoute = (
conversationId,
conversationsDataClient,
messageContent: prunedContent,
messageRefusal: refusal,
replacements: latestReplacements,
isError,
traceData,
Expand Down
Loading