Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

export const ecsMappingExpectedResults = {
mapping: {
mysql_enterprise: {
Expand Down Expand Up @@ -441,18 +442,21 @@ export const ecsTestState = {
ecs: 'teststring',
exAnswer: 'testanswer',
finalized: false,
chunkSize: 30,
currentPipeline: { test: 'testpipeline' },
duplicateFields: [],
missingKeys: [],
invalidEcsFields: [],
finalMapping: { test: 'testmapping' },
sampleChunks: [''],
results: { test: 'testresults' },
samplesFormat: 'testsamplesFormat',
ecsVersion: 'testversion',
currentMapping: { test1: 'test1' },
lastExecutedChain: 'testchain',
rawSamples: ['{"test1": "test1"}'],
samples: ['{ "test1": "test1" }'],
prefixedSamples: ['{ "test1": "test1" }'],
packageName: 'testpackage',
dataStreamName: 'testDataStream',
formattedSamples: '{"test1": "test1"}',
combinedSamples: '{"test1": "test1"}',
};
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import type {
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { CategorizationState } from '../../types';
import { modifySamples, formatSamples } from '../../util/samples';
import { prefixSamples, formatSamples } from '../../util/samples';
import { handleCategorization } from './categorization';
import { handleValidatePipeline } from '../../util/graph';
import { handleCategorizationValidation } from './validate';
Expand Down Expand Up @@ -106,7 +106,7 @@ const graphState: StateGraphArgs<CategorizationState>['channels'] = {
};

function modelInput(state: CategorizationState): Partial<CategorizationState> {
const samples = modifySamples(state);
const samples = prefixSamples(state);
const formattedSamples = formatSamples(samples);
const initialPipeline = JSON.parse(JSON.stringify(state.currentPipeline));
return {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { mergeAndChunkSamples } from './chunk';

describe('test chunks', () => {
  // Two overlapping samples are merged first; the four resulting leaf fields
  // (a, b, c.d, e) are then expected to be split into two chunks of two fields each.
  it('mergeAndChunkSamples()', async () => {
    const rawSamples = ['{"a": 1, "b": 2, "c": {"d": 3}}', '{"a": 2, "b": 3, "e": 4}'];
    const fieldsPerChunk = 2;

    const chunks = mergeAndChunkSamples(rawSamples, fieldsPerChunk);

    expect(chunks).toEqual(['{"a":1,"b":2}', '{"c":{"d":3},"e":4}']);
  });
});
83 changes: 83 additions & 0 deletions x-pack/plugins/integration_assistant/server/graphs/ecs/chunk.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
/* eslint-disable @typescript-eslint/no-explicit-any */

import { merge } from '../../util/samples';

// Loosely typed, arbitrarily nested key/value structure.
// Used in this file for parsed JSON samples, the merged sample object and the generated chunks.
interface NestedObject {
[key: string]: any;
}

// Parses every JSON string in `objects` and folds the parsed samples, in order, into one combined
// object using the shared `merge` helper (per the original notes, `merge` keeps all unique fields and
// prefers non-empty values over empty ones — see util/samples for the exact rules).
// The combined object is then split by `generateChunks` into pieces of `chunkSize` fields, and each
// piece is serialized with JSON.stringify so it can be handed to the ECS mapping subgraph as its
// combinedSamples state.
// Throws whatever JSON.parse throws if one of the input strings is not valid JSON.
export function mergeAndChunkSamples(objects: string[], chunkSize: number): string[] {
  const combined = objects.reduce<NestedObject>(
    (accumulated, rawSample) => merge(accumulated, JSON.parse(rawSample)),
    {}
  );

  const serializedChunks: string[] = [];
  for (const chunk of generateChunks(combined, chunkSize)) {
    serializedChunks.push(JSON.stringify(chunk));
  }
  return serializedChunks;
}

// Splits the already merged samples object into chunks of a given size.
// Size is the count of leaf fields in a chunk: every value that is not a plain nested object
// (primitives, null and arrays) counts as one field. Nested objects are only recursed into, so an
// empty nested object contributes nothing to any chunk.
// Each leaf is written into the chunk under its full original path, so every chunk keeps the nested
// structure of the fields it contains. A final, partially filled chunk is included as well.
// This is to be able to run the ECS mapping sub graph concurrently with a larger number of total
// unique fields.
//
// The returned chunks (and the containers nested inside them) are created without a prototype.
// Field names come from arbitrary sample data, so names such as `constructor`, `toString` or
// `__proto__` must be treated as ordinary keys: on a regular object they would resolve to inherited
// members, which either drops the field, throws, or writes sample data onto Object.prototype.
// Prototype-less objects serialize with JSON.stringify exactly like regular ones.
//
// Note: a chunk is closed when its field count is exactly equal to `chunkSize`; a `chunkSize` that
// is not a positive integer therefore yields a single chunk holding every field.
function generateChunks(mergedSamples: NestedObject, chunkSize: number): NestedObject[] {
  const chunks: NestedObject[] = [];
  let currentChunk: NestedObject = Object.create(null);
  let currentSize = 0;

  function traverse(current: NestedObject, path: string[] = []) {
    for (const [key, value] of Object.entries(current)) {
      const newPath = [...path, key];

      // If the value is a nested object, recurse into it
      if (typeof value === 'object' && value !== null && !Array.isArray(value)) {
        traverse(value, newPath);
      } else {
        // For non-object values, add them to the current chunk
        let target = currentChunk;

        // Recreate the nested structure in the current chunk.
        // Only own properties count as "already created"; inherited members must never be reused.
        for (let i = 0; i < newPath.length - 1; i++) {
          const segment = newPath[i];
          if (!Object.prototype.hasOwnProperty.call(target, segment)) {
            target[segment] = Object.create(null);
          }
          target = target[segment];
        }

        // Add the value to the deepest level of the structure
        target[newPath[newPath.length - 1]] = value;
        currentSize++;

        // If the chunk is full, add it to the chunks and start a new chunk
        if (currentSize === chunkSize) {
          chunks.push(currentChunk);
          currentChunk = Object.create(null);
          currentSize = 0;
        }
      }
    }
  }

  // Start the traversal from the root object
  traverse(mergedSamples);

  // Add any remaining items in the last chunk
  if (currentSize > 0) {
    chunks.push(currentChunk);
  }

  return chunks;
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ export async function handleDuplicates(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
const ecsDuplicatesPrompt = ECS_DUPLICATES_PROMPT;
const outputParser = new JsonOutputParser();
const ecsDuplicatesGraph = ecsDuplicatesPrompt.pipe(model).pipe(outputParser);
const ecsDuplicatesGraph = ECS_DUPLICATES_PROMPT.pipe(model).pipe(outputParser);

const currentMapping = await ecsDuplicatesGraph.invoke({
ecs: state.ecs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ describe('EcsGraph', () => {
it('Runs the whole graph, with mocked outputs from the LLM.', async () => {
// The mocked outputs are specifically crafted to trigger ALL different conditions, allowing us to test the whole graph.
// This is why we have all the expects ensuring each function was called.

const ecsGraph = await getEcsGraph(mockLlm);
const response = await ecsGraph.invoke(mockedRequest);
expect(response.results).toStrictEqual(ecsMappingExpectedResults);
Expand Down
158 changes: 42 additions & 116 deletions x-pack/plugins/integration_assistant/server/graphs/ecs/graph.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,120 +9,32 @@ import type {
ActionsClientChatOpenAI,
ActionsClientSimpleChatModel,
} from '@kbn/langchain/server/language_models';
import type { StateGraphArgs } from '@langchain/langgraph';
import { END, START, StateGraph } from '@langchain/langgraph';
import { END, START, StateGraph, Send } from '@langchain/langgraph';
import type { EcsMappingState } from '../../types';
import { mergeSamples, modifySamples } from '../../util/samples';
import { ECS_EXAMPLE_ANSWER, ECS_FIELDS } from './constants';
import { modelInput, modelOutput, modelSubOutput } from './model';
import { handleDuplicates } from './duplicates';
import { handleInvalidEcs } from './invalid';
import { handleEcsMapping } from './mapping';
import { handleMissingKeys } from './missing';
import { createPipeline } from './pipeline';
import { handleValidateMappings } from './validate';
import { graphState } from './state';

const graphState: StateGraphArgs<EcsMappingState>['channels'] = {
ecs: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
lastExecutedChain: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
rawSamples: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
samples: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
formattedSamples: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
exAnswer: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
packageName: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
dataStreamName: {
value: (x: string, y?: string) => y ?? x,
default: () => '',
},
finalized: {
value: (x: boolean, y?: boolean) => y ?? x,
default: () => false,
},
currentMapping: {
value: (x: object, y?: object) => y ?? x,
default: () => ({}),
},
currentPipeline: {
value: (x: object, y?: object) => y ?? x,
default: () => ({}),
},
duplicateFields: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
missingKeys: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
invalidEcsFields: {
value: (x: string[], y?: string[]) => y ?? x,
default: () => [],
},
results: {
value: (x: object, y?: object) => y ?? x,
default: () => ({}),
},
samplesFormat: {
value: (x: string, y?: string) => y ?? x,
default: () => 'json',
},
ecsVersion: {
value: (x: string, y?: string) => y ?? x,
default: () => '8.11.0',
},
};

function modelInput(state: EcsMappingState): Partial<EcsMappingState> {
const samples = modifySamples(state);
const formattedSamples = mergeSamples(samples);
return {
exAnswer: JSON.stringify(ECS_EXAMPLE_ANSWER, null, 2),
ecs: JSON.stringify(ECS_FIELDS, null, 2),
samples,
finalized: false,
formattedSamples,
lastExecutedChain: 'modelInput',
};
}

function modelOutput(state: EcsMappingState): Partial<EcsMappingState> {
const currentPipeline = createPipeline(state);
return {
finalized: true,
lastExecutedChain: 'modelOutput',
results: {
mapping: state.currentMapping,
pipeline: currentPipeline,
},
const handleCreateMappingChunks = async (state: EcsMappingState) => {
// Cherrypick a shallow copy of state to pass to subgraph
const stateParams = {
exAnswer: state.exAnswer,
prefixedSamples: state.prefixedSamples,
ecs: state.ecs,
dataStreamName: state.dataStreamName,
packageName: state.packageName,
};
}

function inputRouter(state: EcsMappingState): string {
if (Object.keys(state.currentMapping).length === 0) {
return 'ecsMapping';
return state.sampleChunks.map((chunk) => {
return new Send('subGraph', { ...stateParams, combinedSamples: chunk });
});
}
return 'modelOutput';
}
};

function chainRouter(state: EcsMappingState): string {
if (Object.keys(state.duplicateFields).length > 0) {
Expand All @@ -135,38 +47,52 @@ function chainRouter(state: EcsMappingState): string {
return 'invalidEcsFields';
}
if (!state.finalized) {
return 'modelOutput';
return 'modelSubOutput';
}
return END;
}

export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
// This is added as a separate graph to be able to run these steps concurrently from handleCreateMappingChunks
async function getEcsSubGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping(state, model))
.addNode('modelSubOutput', modelSubOutput)
.addNode('handleValidation', handleValidateMappings)
.addNode('handleEcsMapping', (state: EcsMappingState) => handleEcsMapping(state, model))
.addNode('handleDuplicates', (state: EcsMappingState) => handleDuplicates(state, model))
.addNode('handleMissingKeys', (state: EcsMappingState) => handleMissingKeys(state, model))
.addNode('handleInvalidEcs', (state: EcsMappingState) => handleInvalidEcs(state, model))
.addEdge(START, 'modelInput')
.addEdge('modelOutput', END)
.addEdge(START, 'handleEcsMapping')
.addEdge('handleEcsMapping', 'handleValidation')
.addEdge('handleDuplicates', 'handleValidation')
.addEdge('handleMissingKeys', 'handleValidation')
.addEdge('handleInvalidEcs', 'handleValidation')
.addConditionalEdges('modelInput', inputRouter, {
ecsMapping: 'handleEcsMapping',
modelOutput: 'modelOutput',
})
.addConditionalEdges('handleValidation', chainRouter, {
duplicateFields: 'handleDuplicates',
missingKeys: 'handleMissingKeys',
invalidEcsFields: 'handleInvalidEcs',
modelOutput: 'modelOutput',
});
modelSubOutput: 'modelSubOutput',
})
.addEdge('modelSubOutput', END);

const compiledEcsSubGraph = workflow.compile();

return compiledEcsSubGraph;
}

export async function getEcsGraph(model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel) {
const subGraph = await getEcsSubGraph(model);
const workflow = new StateGraph({
channels: graphState,
})
.addNode('modelInput', modelInput)
.addNode('modelOutput', modelOutput)
.addNode('subGraph', subGraph)
.addEdge(START, 'modelInput')
.addEdge('subGraph', 'modelOutput')
.addConditionalEdges('modelInput', handleCreateMappingChunks)
.addEdge('modelOutput', END);

const compiledEcsGraph = workflow.compile();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@ export async function handleInvalidEcs(
state: EcsMappingState,
model: ActionsClientChatOpenAI | ActionsClientSimpleChatModel
) {
const ecsInvalidEcsPrompt = ECS_INVALID_PROMPT;
const outputParser = new JsonOutputParser();
const ecsInvalidEcsGraph = ecsInvalidEcsPrompt.pipe(model).pipe(outputParser);
const ecsInvalidEcsGraph = ECS_INVALID_PROMPT.pipe(model).pipe(outputParser);

const currentMapping = await ecsInvalidEcsGraph.invoke({
ecs: state.ecs,
current_mapping: JSON.stringify(state.currentMapping, null, 2),
ex_answer: state.exAnswer,
formatted_samples: state.formattedSamples,
combined_samples: state.combinedSamples,
invalid_ecs_fields: state.invalidEcsFields,
});

Expand Down
Loading