Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security Solution] DetectionRulesClient: return RuleResponse from all methods #186179

Merged
merged 8 commits into from
Jun 21, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import type { SecuritySolutionPluginRouter } from '../../../../../types';
import { buildRouteValidation } from '../../../../../utils/build_validation/route_validation';
import type { PromisePoolError } from '../../../../../utils/promise_pool';
import { buildSiemResponse } from '../../../routes/utils';
import { internalRuleToAPIResponse } from '../../../rule_management/normalization/rule_converters';
import { aggregatePrebuiltRuleErrors } from '../../logic/aggregate_prebuilt_rule_errors';
import { performTimelinesInstallation } from '../../logic/perform_timelines_installation';
import { createPrebuiltRuleAssetsClient } from '../../logic/rule_assets/prebuilt_rule_assets_client';
Expand Down Expand Up @@ -182,7 +181,7 @@ export const performRuleUpgradeRoute = (router: SecuritySolutionPluginRouter) =>
failed: ruleErrors.length,
},
results: {
updated: updatedRules.map(({ result }) => internalRuleToAPIResponse(result)),
updated: updatedRules.map(({ result }) => result),
skipped: skippedRules,
},
errors: allErrors,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { serverMock, requestContextMock, requestMock } from '../../../../routes/__mocks__';
import {
getRulesSchemaMock,
getRulesMlSchemaMock,
} from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { bulkPatchRulesRoute } from './route';
import { getCreateRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/mocks';
import { getMlRuleParams, getQueryRuleParams } from '../../../../rule_schema/mocks';
Expand All @@ -34,7 +38,7 @@ describe('Bulk patch rules route', () => {

clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // rule exists
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // update succeeds
clients.detectionRulesClient.patchRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.patchRule.mockResolvedValue(getRulesSchemaMock());

bulkPatchRulesRoute(server.router, logger);
});
Expand Down Expand Up @@ -72,14 +76,11 @@ describe('Bulk patch rules route', () => {
...getFindResultWithSingleHit(),
data: [getRuleMock(getMlRuleParams())],
});
clients.detectionRulesClient.patchRule.mockResolvedValueOnce(
getRuleMock(
getMlRuleParams({
anomalyThreshold,
machineLearningJobId: [machineLearningJobId],
})
)
);
clients.detectionRulesClient.patchRule.mockResolvedValueOnce({
...getRulesMlSchemaMock(),
anomaly_threshold: anomalyThreshold,
machine_learning_job_id: [machineLearningJobId],
});

const request = requestMock.create({
method: 'patch',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
import type { SecuritySolutionPluginRouter } from '../../../../../../types';
import { transformBulkError, buildSiemResponse } from '../../../../routes/utils';
import { getIdBulkError } from '../../../utils/utils';
import { transformValidateBulkError } from '../../../utils/validate';
import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { getDeprecatedBulkEndpointHeader, logDeprecatedBulkEndpoint } from '../../deprecation';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
Expand Down Expand Up @@ -86,11 +85,11 @@ export const bulkPatchRulesRoute = (router: SecuritySolutionPluginRouter, logger
ruleId: payloadRule.id,
});

const rule = await detectionRulesClient.patchRule({
const patchedRule = await detectionRulesClient.patchRule({
nextParams: payloadRule,
});

return transformValidateBulkError(rule.id, rule);
return patchedRule;
} catch (err) {
return transformBulkError(idOrRuleIdOrUnknown, err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { serverMock, requestContextMock, requestMock } from '../../../../routes/__mocks__';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { bulkUpdateRulesRoute } from './route';
import type { BulkError } from '../../../../routes/utils';
import { getCreateRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/mocks';
Expand All @@ -32,7 +33,7 @@ describe('Bulk update rules route', () => {

clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit());
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRulesSchemaMock());
clients.appClient.getSignalsIndex.mockReturnValue('.siem-signals-test-index');

bulkUpdateRulesRoute(server.router, logger);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import {
import type { SecuritySolutionPluginRouter } from '../../../../../../types';
import { DETECTION_ENGINE_RULES_BULK_UPDATE } from '../../../../../../../common/constants';
import { getIdBulkError } from '../../../utils/utils';
import { transformValidateBulkError } from '../../../utils/validate';
import {
transformBulkError,
buildSiemResponse,
Expand Down Expand Up @@ -97,11 +96,11 @@ export const bulkUpdateRulesRoute = (router: SecuritySolutionPluginRouter, logge
ruleId: payloadRule.id,
});

const rule = await detectionRulesClient.updateRule({
const updatedRule = await detectionRulesClient.updateRule({
ruleUpdate: payloadRule,
});

return transformValidateBulkError(rule.id, rule);
return updatedRule;
} catch (err) {
return transformBulkError(idOrRuleIdOrUnknown, err);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import {
ruleIdsToNdJsonString,
rulesToNdJsonString,
} from '../../../../../../../common/api/detection_engine/rule_management/mocks';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';

import type { requestMock } from '../../../../routes/__mocks__';
import { createMockConfig, requestContextMock, serverMock } from '../../../../routes/__mocks__';
Expand Down Expand Up @@ -47,7 +48,8 @@ describe('Import rules route', () => {

clients.rulesClient.find.mockResolvedValue(getEmptyFindResult()); // no extant rules
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.importRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.createCustomRule.mockResolvedValue(getRulesSchemaMock());
clients.detectionRulesClient.importRule.mockResolvedValue(getRulesSchemaMock());
clients.actionsClient.getAll.mockResolvedValue([]);
context.core.elasticsearch.client.asCurrentUser.search.mockResolvedValue(
elasticsearchClientMock.createSuccessTransportRequestPromise(getBasicEmptySearchResponse())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ import {

import { getMlRuleParams, getQueryRuleParams } from '../../../../rule_schema/mocks';

import {
getRulesSchemaMock,
getRulesMlSchemaMock,
} from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';

import { patchRuleRoute } from './route';
import { HttpAuthzError } from '../../../../../machine_learning/validation';

Expand All @@ -34,7 +39,7 @@ describe('Patch rule route', () => {
clients.rulesClient.get.mockResolvedValue(getRuleMock(getQueryRuleParams())); // existing rule
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // existing rule
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // successful update
clients.detectionRulesClient.patchRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.patchRule.mockResolvedValue(getRulesSchemaMock());

patchRuleRoute(server.router);
});
Expand Down Expand Up @@ -99,14 +104,11 @@ describe('Patch rule route', () => {

const anomalyThreshold = 4;
const machineLearningJobId = 'some_job_id';
clients.detectionRulesClient.patchRule.mockResolvedValueOnce(
getRuleMock(
getMlRuleParams({
anomalyThreshold,
machineLearningJobId: [machineLearningJobId],
})
)
);
clients.detectionRulesClient.patchRule.mockResolvedValueOnce({
...getRulesMlSchemaMock(),
anomaly_threshold: anomalyThreshold,
machine_learning_job_id: [machineLearningJobId],
});

const request = requestMock.create({
method: 'patch',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { checkDefaultRuleExceptionListReferences } from '../../../logic/exceptions/check_for_default_rule_exception_list';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
import { getIdError } from '../../../utils/utils';
import { transformValidate } from '../../../utils/validate';

export const patchRuleRoute = (router: SecuritySolutionPluginRouter) => {
router.versioned
Expand Down Expand Up @@ -76,12 +75,12 @@ export const patchRuleRoute = (router: SecuritySolutionPluginRouter) => {
ruleId: params.id,
});

const rule = await detectionRulesClient.patchRule({
const patchedRule = await detectionRulesClient.patchRule({
nextParams: params,
});

return response.ok({
body: transformValidate(rule),
body: patchedRule,
});
} catch (err) {
const error = transformError(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import {
typicalMlRulePayload,
} from '../../../../routes/__mocks__/request_responses';
import { requestContextMock, serverMock, requestMock } from '../../../../routes/__mocks__';
import { getRulesSchemaMock } from '../../../../../../../common/api/detection_engine/model/rule_schema/rule_response_schema.mock';
import { DETECTION_ENGINE_RULES_URL } from '../../../../../../../common/constants';
import { updateRuleRoute } from './route';
import {
Expand All @@ -34,7 +35,7 @@ describe('Update rule route', () => {
clients.rulesClient.get.mockResolvedValue(getRuleMock(getQueryRuleParams())); // existing rule
clients.rulesClient.find.mockResolvedValue(getFindResultWithSingleHit()); // rule exists
clients.rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams())); // successful update
clients.detectionRulesClient.updateRule.mockResolvedValue(getRuleMock(getQueryRuleParams()));
clients.detectionRulesClient.updateRule.mockResolvedValue(getRulesSchemaMock());
clients.appClient.getSignalsIndex.mockReturnValue('.siem-signals-test-index');

updateRuleRoute(server.router);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { readRules } from '../../../logic/detection_rules_client/read_rules';
import { checkDefaultRuleExceptionListReferences } from '../../../logic/exceptions/check_for_default_rule_exception_list';
import { validateRuleDefaultExceptionList } from '../../../logic/exceptions/validate_rule_default_exception_list';
import { getIdError } from '../../../utils/utils';
import { transformValidate, validateResponseActionsPermissions } from '../../../utils/validate';
import { validateResponseActionsPermissions } from '../../../utils/validate';

export const updateRuleRoute = (router: SecuritySolutionPluginRouter) => {
router.versioned
Expand Down Expand Up @@ -80,12 +80,12 @@ export const updateRuleRoute = (router: SecuritySolutionPluginRouter) => {
existingRule
);

const rule = await detectionRulesClient.updateRule({
const updatedRule = await detectionRulesClient.updateRule({
ruleUpdate: request.body,
});

return response.ok({
body: transformValidate(rule),
body: updatedRule,
});
} catch (err) {
const error = transformError(err);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ import { buildMlAuthz } from '../../../../machine_learning/authz';
import { throwAuthzError } from '../../../../machine_learning/validation';
import { createDetectionRulesClient } from './detection_rules_client';
import type { IDetectionRulesClient } from './detection_rules_client_interface';
import { RuleResponseValidationError } from './utils';
import type { RuleAlertType } from '../../../rule_schema';

jest.mock('../../../../machine_learning/authz');
jest.mock('../../../../machine_learning/validation');
Expand Down Expand Up @@ -70,20 +68,6 @@ describe('DetectionRulesClient.createCustomRule', () => {
expect(rulesClient.create).not.toHaveBeenCalled();
});

it('throws if RuleResponse validation fails', async () => {
const internalRuleMock: RuleAlertType = getRuleMock({
...getQueryRuleParams(),
/* Casting as 'query' suppress to TS error */
type: 'fake-non-existent-type' as 'query',
});

rulesClient.create.mockResolvedValueOnce(internalRuleMock);

await expect(
detectionRulesClient.createCustomRule({ params: getCreateMachineLearningRulesSchemaMock() })
).rejects.toThrow(RuleResponseValidationError);
});

it('calls the rulesClient with legacy ML params', async () => {
await detectionRulesClient.createCustomRule({
params: getCreateMachineLearningRulesSchemaMock(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ describe('DetectionRulesClient.importRule', () => {

beforeEach(() => {
rulesClient = rulesClientMock.create();
rulesClient.create.mockResolvedValue(getRuleMock(getQueryRuleParams()));
rulesClient.update.mockResolvedValue(getRuleMock(getQueryRuleParams()));
detectionRulesClient = createDetectionRulesClient(rulesClient, mlAuthz);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import type { RulesClient } from '@kbn/alerting-plugin/server';
import type { MlAuthz } from '../../../../machine_learning/authz';

import type { RuleAlertType } from '../../../rule_schema';
import type { RuleResponse } from '../../../../../../common/api/detection_engine/model/rule_schema';
import type {
IDetectionRulesClient,
Expand Down Expand Up @@ -47,13 +46,13 @@ export const createDetectionRulesClient = (
});
},

async updateRule(args: UpdateRuleArgs): Promise<RuleAlertType> {
async updateRule(args: UpdateRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.updateRule', async () => {
return updateRule(rulesClient, args, mlAuthz);
});
},

async patchRule(args: PatchRuleArgs): Promise<RuleAlertType> {
async patchRule(args: PatchRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.patchRule', async () => {
return patchRule(rulesClient, args, mlAuthz);
});
Expand All @@ -65,13 +64,13 @@ export const createDetectionRulesClient = (
});
},

async upgradePrebuiltRule(args: UpgradePrebuiltRuleArgs): Promise<RuleAlertType> {
async upgradePrebuiltRule(args: UpgradePrebuiltRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.upgradePrebuiltRule', async () => {
return upgradePrebuiltRule(rulesClient, args, mlAuthz);
});
},

async importRule(args: ImportRuleArgs): Promise<RuleAlertType> {
async importRule(args: ImportRuleArgs): Promise<RuleResponse> {
return withSecuritySpan('DetectionRulesClient.importRule', async () => {
return importRule(rulesClient, args, mlAuthz);
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,12 @@ describe('DetectionRulesClient.upgradePrebuiltRule', () => {
ruleId: 'rule-id',
});
beforeEach(() => {
jest.resetAllMocks();
rulesClient.create.mockResolvedValue(getRuleMock(getQueryRuleParams()));
(readRules as jest.Mock).mockResolvedValue(installedRule);
});

it('deletes the old rule ', async () => {
it('deletes the old rule', async () => {
await detectionRulesClient.upgradePrebuiltRule({ ruleAsset });
expect(rulesClient.delete).toHaveBeenCalled();
});
Expand Down Expand Up @@ -153,6 +155,8 @@ describe('DetectionRulesClient.upgradePrebuiltRule', () => {
});

it('patches the existing rule with the new params from the rule asset', async () => {
rulesClient.update.mockResolvedValue(getRuleMock(getEqlRuleParams()));

await detectionRulesClient.upgradePrebuiltRule({ ruleAsset });
expect(rulesClient.update).toHaveBeenCalledWith(
expect.objectContaining({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@ import type {
RuleToImport,
RuleResponse,
} from '../../../../../../common/api/detection_engine';
import type { RuleAlertType } from '../../../rule_schema';
import type { PrebuiltRuleAsset } from '../../../prebuilt_rules';

export interface IDetectionRulesClient {
createCustomRule: (args: CreateCustomRuleArgs) => Promise<RuleResponse>;
createPrebuiltRule: (args: CreatePrebuiltRuleArgs) => Promise<RuleResponse>;
updateRule: (args: UpdateRuleArgs) => Promise<RuleAlertType>;
patchRule: (args: PatchRuleArgs) => Promise<RuleAlertType>;
updateRule: (args: UpdateRuleArgs) => Promise<RuleResponse>;
patchRule: (args: PatchRuleArgs) => Promise<RuleResponse>;
deleteRule: (args: DeleteRuleArgs) => Promise<void>;
upgradePrebuiltRule: (args: UpgradePrebuiltRuleArgs) => Promise<RuleAlertType>;
importRule: (args: ImportRuleArgs) => Promise<RuleAlertType>;
upgradePrebuiltRule: (args: UpgradePrebuiltRuleArgs) => Promise<RuleResponse>;
importRule: (args: ImportRuleArgs) => Promise<RuleResponse>;
}

export interface CreateCustomRuleArgs {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ import type { CreateCustomRuleArgs } from '../detection_rules_client_interface';
import type { MlAuthz } from '../../../../../machine_learning/authz';
import type { RuleParams } from '../../../../rule_schema';
import { RuleResponse } from '../../../../../../../common/api/detection_engine/model/rule_schema';
import { convertCreateAPIToInternalSchema } from '../../../normalization/rule_converters';
import { transform } from '../../../utils/utils';
import {
convertCreateAPIToInternalSchema,
internalRuleToAPIResponse,
} from '../../../normalization/rule_converters';
import { validateMlAuth, RuleResponseValidationError } from '../utils';

export const createCustomRule = async (
Expand All @@ -29,7 +31,7 @@ export const createCustomRule = async (
});

/* Trying to convert the rule to a RuleResponse object */
const parseResult = RuleResponse.safeParse(transform(rule));
const parseResult = RuleResponse.safeParse(internalRuleToAPIResponse(rule));

if (!parseResult.success) {
throw new RuleResponseValidationError({
Expand Down
Loading