Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ export const buildThreatEnrichment = ({
const signalsQueryMap = await getSignalsQueryMapFromThreatIndex({
threatSearchParams,
eventsCount: signals.length,
termsQueryAllowed: false,
});

const enrichment = threatEnrichmentFactory({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ export const createEventSignal = async ({
threatSearchParams,
eventsCount: currentEventList.length,
signalValueMap: getSignalValueMap({ eventList: currentEventList, threatMatchedFields }),
termsQueryAllowed: true,
});

const ids = Array.from(signalsQueryMap.keys());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
await getSignalsQueryMapFromThreatIndex({
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
termsQueryAllowed: false,
});

expect(getThreatListMock).toHaveBeenCalledTimes(1);
Expand All @@ -65,6 +66,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
const signalsQueryMap = await getSignalsQueryMapFromThreatIndex({
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
termsQueryAllowed: false,
});

expect(signalsQueryMap).toEqual(new Map());
Expand Down Expand Up @@ -98,6 +100,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
const signalsQueryMap = await getSignalsQueryMapFromThreatIndex({
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
termsQueryAllowed: false,
});

expect(signalsQueryMap).toEqual(
Expand Down Expand Up @@ -153,6 +156,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
const signalsQueryMap = await getSignalsQueryMapFromThreatIndex({
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
termsQueryAllowed: false,
});

expect(signalsQueryMap.get('source-1')).toHaveLength(MAX_NUMBER_OF_SIGNAL_MATCHES);
Expand All @@ -168,6 +172,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
const signalsQueryMap = await getSignalsQueryMapFromThreatIndex({
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
termsQueryAllowed: false,
});

expect(signalsQueryMap).toEqual(new Map());
Expand Down Expand Up @@ -201,6 +206,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
signalValueMap,
termsQueryAllowed: true,
});

expect(signalsQueryMap).toEqual(new Map());
Expand Down Expand Up @@ -234,6 +240,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
signalValueMap,
termsQueryAllowed: true,
});

const queries = [
Expand Down Expand Up @@ -283,6 +290,7 @@ describe('getSignalsQueryMapFromThreatIndex', () => {
threatSearchParams: threatSearchParamsMock,
eventsCount: 50,
signalValueMap,
termsQueryAllowed: true,
});

const queries = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,34 @@ import { MAX_NUMBER_OF_SIGNAL_MATCHES } from './enrich_signal_threat_matches';

export type SignalsQueryMap = Map<string, ThreatMatchNamedQuery[]>;

interface GetSignalsMatchesFromThreatIndexOptions {
interface GetSignalsQueryMapFromThreatIndexOptionsTerms {
threatSearchParams: Omit<GetThreatListOptions, 'searchAfter'>;
eventsCount: number;
signalValueMap?: SignalValuesMap;
signalValueMap: SignalValuesMap;
termsQueryAllowed: true;
}

interface GetSignalsQueryMapFromThreatIndexOptionsMatch {
threatSearchParams: Omit<GetThreatListOptions, 'searchAfter'>;
eventsCount: number;
termsQueryAllowed: false;
}

/**
* fetches threats and creates signals map from results, that matches signal is with list of threat queries
*/
export const getSignalsQueryMapFromThreatIndex = async ({
threatSearchParams,
eventsCount,
signalValueMap,
}: GetSignalsMatchesFromThreatIndexOptions): Promise<SignalsQueryMap> => {
/**
* fetches threats and creates signals map from results, that matches signal is with list of threat queries
* @param options.termsQueryAllowed - if terms query allowed to be executed, then signalValueMap should be provided
* @param options.signalValueMap - map of signal values from terms query results
*/
export async function getSignalsQueryMapFromThreatIndex(
options:
| GetSignalsQueryMapFromThreatIndexOptionsTerms
| GetSignalsQueryMapFromThreatIndexOptionsMatch
): Promise<SignalsQueryMap> {
const { threatSearchParams, eventsCount, termsQueryAllowed } = options;

let threatList: Awaited<ReturnType<typeof getThreatList>> | undefined;
const signalsQueryMap = new Map<string, ThreatMatchNamedQuery[]>();
// number of threat matches per signal is limited by MAX_NUMBER_OF_SIGNAL_MATCHES. Once it hits this number, threats stop to be processed for a signal
Expand All @@ -50,9 +64,6 @@ export const getSignalsQueryMapFromThreatIndex = async ({
decodedQuery: ThreatMatchNamedQuery | ThreatTermNamedQuery;
}) => {
const signalMatch = signalsQueryMap.get(signalId);
if (!signalMatch) {
signalsQueryMap.set(signalId, []);
}

const threatQuery = {
id: threatHit._id,
Expand All @@ -74,29 +85,23 @@ export const getSignalsQueryMapFromThreatIndex = async ({
}
};

while (
maxThreatsReachedMap.size < eventsCount &&
(threatList ? threatList?.hits.hits.length > 0 : true)
) {
threatList = await getThreatList({
...threatSearchParams,
searchAfter: threatList?.hits.hits[threatList.hits.hits.length - 1].sort || undefined,
});
threatList = await getThreatList({ ...threatSearchParams, searchAfter: undefined });

while (maxThreatsReachedMap.size < eventsCount && threatList?.hits.hits.length > 0) {
threatList.hits.hits.forEach((threatHit) => {
const matchedQueries = threatHit?.matched_queries || [];

matchedQueries.forEach((matchedQuery) => {
const decodedQuery = decodeThreatMatchNamedQuery(matchedQuery);
const signalId = decodedQuery.id;

if (decodedQuery.queryType === ThreatMatchQueryType.term) {
if (decodedQuery.queryType === ThreatMatchQueryType.term && termsQueryAllowed) {
const threatValue = get(threatHit?._source, decodedQuery.value);
const values = Array.isArray(threatValue) ? threatValue : [threatValue];

values.forEach((value) => {
if (value && signalValueMap) {
const ids = signalValueMap[decodedQuery.field][value?.toString()];
if (value && options.signalValueMap) {
const ids = options.signalValueMap[decodedQuery.field][value?.toString()];

ids?.forEach((id: string) => {
addSignalValueToMap({
Expand All @@ -120,7 +125,12 @@ export const getSignalsQueryMapFromThreatIndex = async ({
}
});
});

threatList = await getThreatList({
...threatSearchParams,
searchAfter: threatList.hits.hits[threatList.hits.hits.length - 1].sort,
});
}

return signalsQueryMap;
};
}