diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts index 538d8016a0a73..b2616ed7615ba 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.test.ts @@ -17,6 +17,7 @@ import { getMlModelTypesForModelConfig, getSetProcessorForInferenceType, SUPPORTED_PYTORCH_TASKS as LOCAL_SUPPORTED_PYTORCH_TASKS, + parseMlInferenceParametersFromPipeline, } from '.'; const mockModel: MlTrainedModelConfig = { @@ -198,3 +199,45 @@ describe('generateMlInferencePipelineBody lib function', () => { ); }); }); + +describe('parseMlInferenceParametersFromPipeline', () => { + it('returns pipeline parameters from ingest pipeline', () => { + expect( + parseMlInferenceParametersFromPipeline('unit-test', { + processors: [ + { + inference: { + field_map: { + body: 'text_field', + }, + model_id: 'test-model', + target_field: 'ml.inference.test', + }, + }, + ], + }) + ).toEqual({ + destination_field: 'test', + model_id: 'test-model', + pipeline_name: 'unit-test', + source_field: 'body', + }); + }); + it('return null if pipeline missing inference processor', () => { + expect(parseMlInferenceParametersFromPipeline('unit-test', { processors: [] })).toBeNull(); + }); + it('return null if pipeline missing field_map', () => { + expect( + parseMlInferenceParametersFromPipeline('unit-test', { + processors: [ + { + inference: { + model_id: 'test-model', + target_field: 'test', + }, + }, + ], + }) + ).toBeNull(); + }); +}); diff --git a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts index b5b4526d1723b..4e5b124f8dff0 100644 --- a/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts +++ b/x-pack/plugins/enterprise_search/common/ml_inference_pipeline/index.ts @@ -5,9 +5,13 @@ * 2.0. */ -import { IngestSetProcessor, MlTrainedModelConfig } from '@elastic/elasticsearch/lib/api/types'; +import { + IngestPipeline, + IngestSetProcessor, + MlTrainedModelConfig, +} from '@elastic/elasticsearch/lib/api/types'; -import { MlInferencePipeline } from '../types/pipelines'; +import { MlInferencePipeline, CreateMlInferencePipelineParameters } from '../types/pipelines'; // Getting an error importing this from @kbn/ml-plugin/common/constants/data_frame_analytics' // So defining it locally for now with a test to make sure it matches. @@ -151,3 +155,25 @@ export const formatPipelineName = (rawName: string) => .trim() .replace(/\s+/g, '_') // Convert whitespaces to underscores .toLowerCase(); + +export const parseMlInferenceParametersFromPipeline = ( + name: string, + pipeline: IngestPipeline +): CreateMlInferencePipelineParameters | null => { + const processor = pipeline?.processors?.find((proc) => proc.inference !== undefined); + if (!processor || processor?.inference === undefined) { + return null; + } + const { inference: inferenceProcessor } = processor; + const sourceFields = Object.keys(inferenceProcessor.field_map ?? {}); + const sourceField = sourceFields.length === 1 ? sourceFields[0] : null; + if (!sourceField) { + return null; + } + return { + destination_field: inferenceProcessor.target_field.replace('ml.inference.', ''), + model_id: inferenceProcessor.model_id, + pipeline_name: name, + source_field: sourceField, + }; +}; diff --git a/x-pack/plugins/enterprise_search/common/types/pipelines.ts b/x-pack/plugins/enterprise_search/common/types/pipelines.ts index 9b53e98d584d7..38314f6d162de 100644 --- a/x-pack/plugins/enterprise_search/common/types/pipelines.ts +++ b/x-pack/plugins/enterprise_search/common/types/pipelines.ts @@ -64,3 +64,10 @@ export interface DeleteMlInferencePipelineResponse { deleted?: string; updated?: string; } + +export interface CreateMlInferencePipelineParameters { + destination_field?: string; + model_id: string; + pipeline_name: string; + source_field: string; +} diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.test.ts new file mode 100644 index 0000000000000..4c88466ba32b7 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.test.ts @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +import { mockHttpValues } from '../../../__mocks__/kea_logic'; + +import { + attachMlInferencePipeline, + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse, +} from './attach_ml_inference_pipeline'; + +describe('AttachMlInferencePipelineApiLogic', () => { + const { http } = mockHttpValues; + beforeEach(() => { + jest.clearAllMocks(); + }); + describe('createMlInferencePipeline', () => { + it('calls the api', async () => { + const response: Promise = Promise.resolve({ + addedToParentPipeline: true, + created: false, + id: 'unit-test', + }); + http.post.mockReturnValue(response); + + const args: AttachMlInferencePipelineApiLogicArgs = { + indexName: 'unit-test-index', + pipelineName: 'unit-test', + }; + const result = await attachMlInferencePipeline(args); + expect(http.post).toHaveBeenCalledWith( + '/internal/enterprise_search/indices/unit-test-index/ml_inference/pipeline_processors/attach', + { + body: '{"pipeline_name":"unit-test"}', + } + ); + expect(result).toEqual({ + addedToParentPipeline: true, + created: false, + id: args.pipelineName, + }); + }); + }); +}); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.ts new file mode 100644 index 0000000000000..433c41a75ea0f --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/attach_ml_inference_pipeline.ts @@ -0,0 +1,36 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { AttachMlInferencePipelineResponse } from '../../../../../common/types/pipelines'; + +import { createApiLogic } from '../../../shared/api_logic/create_api_logic'; +import { HttpLogic } from '../../../shared/http'; + +export interface AttachMlInferencePipelineApiLogicArgs { + indexName: string; + pipelineName: string; +} + +export type { AttachMlInferencePipelineResponse }; + +export const attachMlInferencePipeline = async ( + args: AttachMlInferencePipelineApiLogicArgs +): Promise => { + const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors/attach`; + const params = { + pipeline_name: args.pipelineName, + }; + + return await HttpLogic.values.http.post(route, { + body: JSON.stringify(params), + }); +}; + +export const AttachMlInferencePipelineApiLogic = createApiLogic( + ['attach_ml_inference_pipeline_api_logic'], + attachMlInferencePipeline +); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/create_ml_inference_pipeline.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.test.ts similarity index 100% rename from x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/create_ml_inference_pipeline.test.ts rename to x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.test.ts diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/create_ml_inference_pipeline.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts similarity index 89% rename from x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/create_ml_inference_pipeline.ts rename to x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts index ee5e7dd1c4295..78f08c4bc0ee8 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/ml_models/create_ml_inference_pipeline.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/create_ml_inference_pipeline.ts @@ -4,6 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ +import { CreateMlInferencePipelineParameters } from '../../../../../common/types/pipelines'; import { createApiLogic } from '../../../shared/api_logic/create_api_logic'; import { HttpLogic } from '../../../shared/http'; @@ -23,7 +24,7 @@ export const createMlInferencePipeline = async ( args: CreateMlInferencePipelineApiLogicArgs ): Promise => { const route = `/internal/enterprise_search/indices/${args.indexName}/ml_inference/pipeline_processors`; - const params = { + const params: CreateMlInferencePipelineParameters = { destination_field: args.destinationField, model_id: args.modelId, pipeline_name: args.pipelineName, diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipeline_processors.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipeline_processors.ts index 85f481b513525..2d881a0463bb7 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipeline_processors.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipeline_processors.ts @@ -9,10 +9,18 @@ import { InferencePipeline } from '../../../../../common/types/pipelines'; import { createApiLogic } from '../../../shared/api_logic/create_api_logic'; import { HttpLogic } from '../../../shared/http'; -export const fetchMlInferencePipelineProcessors = async ({ indexName }: { indexName: string }) => { +export interface FetchMlInferencePipelineProcessorsApiLogicArgs { + indexName: string; +} + +export type FetchMlInferencePipelineProcessorsResponse = InferencePipeline[]; + +export const fetchMlInferencePipelineProcessors = async ({ + indexName, +}: FetchMlInferencePipelineProcessorsApiLogicArgs) => { const route = `/internal/enterprise_search/indices/${indexName}/ml_inference/pipeline_processors`; - return await HttpLogic.values.http.get(route); + return await HttpLogic.values.http.get(route); }; export const FetchMlInferencePipelineProcessorsApiLogic = createApiLogic( diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipelines.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipelines.ts new file mode 100644 index 0000000000000..d5df97d259fda --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/api/pipelines/fetch_ml_inference_pipelines.ts @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import { MlInferencePipeline } from '../../../../../common/types/pipelines'; +import { createApiLogic } from '../../../shared/api_logic/create_api_logic'; +import { HttpLogic } from '../../../shared/http'; + +export type FetchMlInferencePipelinesArgs = undefined; +export type FetchMlInferencePipelinesResponse = Record; + +export const fetchMlInferencePipelines = async () => { + const route = '/internal/enterprise_search/pipelines/ml_inference'; + + return await HttpLogic.values.http.get(route); +}; + +export const FetchMlInferencePipelinesApiLogic = createApiLogic( + ['fetch_ml_inference_pipelines_api_logic'], + fetchMlInferencePipelines +); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_ml_inference_pipeline_modal.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_ml_inference_pipeline_modal.tsx index edbf18f8b009c..cc0cc3eb8f954 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_ml_inference_pipeline_modal.tsx +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/add_ml_inference_pipeline_modal.tsx @@ -92,7 +92,7 @@ const AddProcessorContent: React.FC = ({ onClo ); } - if (supportedMLModels === undefined || supportedMLModels?.length === 0) { + if (supportedMLModels.length === 0) { return ; } return ( @@ -188,8 +188,10 @@ const ModalFooter: React.FC { const { addInferencePipelineModal: modal, isPipelineDataValid } = useValues(MLInferenceLogic); - const { createPipeline, setAddInferencePipelineStep } = useActions(MLInferenceLogic); + const { attachPipeline, createPipeline, setAddInferencePipelineStep } = + useActions(MLInferenceLogic); + const attachExistingPipeline = Boolean(modal.configuration.existingPipeline); let nextStep: AddInferencePipelineSteps | undefined; let previousStep: AddInferencePipelineSteps | undefined; switch (modal.step) { @@ -239,6 +241,21 @@ const ModalFooter: React.FC {CONTINUE_BUTTON_LABEL} + ) : attachExistingPipeline ? ( + + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.transforms.addInferencePipelineModal.footer.attach', + { + defaultMessage: 'Attach', + } + )} + ) : ( ( { const { addInferencePipelineModal: { configuration }, formErrors, + existingInferencePipelines, supportedMLModels, sourceFields, } = useValues(MLInferenceLogic); - const { setInferencePipelineConfiguration } = useActions(MLInferenceLogic); + const { selectExistingPipeline, setInferencePipelineConfiguration } = + useActions(MLInferenceLogic); const { ingestionMethod } = useValues(IndexViewLogic); const { destinationField, modelID, pipelineName, sourceField } = configuration; - const models = supportedMLModels ?? []; const nameError = formErrors.pipelineName !== undefined && pipelineName.length > 0; const emptySourceFields = (sourceFields?.length ?? 0) === 0; @@ -76,12 +92,30 @@ export const ConfigurePipeline: React.FC = () => { ), value: MODEL_SELECT_PLACEHOLDER_VALUE, }, - ...models.map((model) => ({ + ...supportedMLModels.map((model) => ({ dropdownDisplay: , inputDisplay: model.model_id, value: model.model_id, })), ]; + const pipelineOptions: Array> = [ + { + disabled: true, + inputDisplay: i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.placeholder', + { defaultMessage: 'Select one' } + ), + value: PIPELINE_SELECT_PLACEHOLDER_VALUE, + }, + ...(existingInferencePipelines?.map((pipeline) => ({ + disabled: pipeline.disabled, + dropdownDisplay: , + inputDisplay: pipeline.pipelineName, + value: pipeline.pipelineName, + })) ?? []), + ]; + + const inputsDisabled = configuration.existingPipeline !== false; return ( <> @@ -106,45 +140,107 @@ export const ConfigurePipeline: React.FC = () => { - - + + + + setInferencePipelineConfiguration({ + ...EMPTY_PIPELINE_CONFIGURATION, + existingPipeline: e.target.value === 'true', + }) + } + /> + + + + {configuration.existingPipeline === true ? ( + + 0 ? pipelineName : PIPELINE_SELECT_PLACEHOLDER_VALUE + } + options={pipelineOptions} + onChange={(value) => selectExistingPipeline(value)} + /> + + ) : ( + + + setInferencePipelineConfiguration({ + ...configuration, + pipelineName: e.target.value, + }) + } + /> + )} - value={pipelineName} - onChange={(e) => - setInferencePipelineConfiguration({ - ...configuration, - pipelineName: e.target.value, - }) - } - /> - + + { data-telemetry-id={`entSearchContent-${ingestionMethod}-pipelines-configureInferencePipeline-selectTrainedModel`} fullWidth hasDividers + disabled={inputsDisabled} itemLayoutAlign="top" onChange={(value) => setInferencePipelineConfiguration({ @@ -185,6 +282,7 @@ export const ConfigurePipeline: React.FC = () => { > { > diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts index c605009d7eb0d..4224c150af904 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.test.ts @@ -7,20 +7,27 @@ import { LogicMounter } from '../../../../../__mocks__/kea_logic'; -import { HttpError, Status } from '../../../../../../../common/types/api'; +import { HttpResponse } from '@kbn/core/public'; + +import { ErrorResponse, HttpError, Status } from '../../../../../../../common/types/api'; +import { TrainedModelState } from '../../../../../../../common/types/pipelines'; import { MappingsApiLogic } from '../../../../api/mappings/mappings_logic'; -import { CreateMlInferencePipelineApiLogic } from '../../../../api/ml_models/create_ml_inference_pipeline'; import { MLModelsApiLogic } from '../../../../api/ml_models/ml_models_logic'; +import { AttachMlInferencePipelineApiLogic } from '../../../../api/pipelines/attach_ml_inference_pipeline'; +import { CreateMlInferencePipelineApiLogic } from '../../../../api/pipelines/create_ml_inference_pipeline'; +import { FetchMlInferencePipelineProcessorsApiLogic } from '../../../../api/pipelines/fetch_ml_inference_pipeline_processors'; +import { FetchMlInferencePipelinesApiLogic } from '../../../../api/pipelines/fetch_ml_inference_pipelines'; import { SimulateMlInterfacePipelineApiLogic } from '../../../../api/pipelines/simulate_ml_inference_pipeline_processors'; import { MLInferenceLogic, EMPTY_PIPELINE_CONFIGURATION, AddInferencePipelineSteps, + MLInferenceProcessorsValues, } from './ml_inference_logic'; -const DEFAULT_VALUES = { +const DEFAULT_VALUES: MLInferenceProcessorsValues = { addInferencePipelineModal: { configuration: { ...EMPTY_PIPELINE_CONFIGURATION, @@ -46,6 +53,7 @@ const DEFAULT_VALUES = { step: AddInferencePipelineSteps.Configuration, }, createErrors: [], + existingInferencePipelines: [], formErrors: { modelID: 'Field is required.', pipelineName: 'Field is required.', @@ -57,6 +65,8 @@ const DEFAULT_VALUES = { mappingData: undefined, mappingStatus: 0, mlInferencePipeline: undefined, + mlInferencePipelineProcessors: undefined, + mlInferencePipelinesData: undefined, mlModelsData: undefined, mlModelsStatus: 0, simulatePipelineData: undefined, @@ -64,7 +74,7 @@ const DEFAULT_VALUES = { simulatePipelineResult: undefined, simulatePipelineStatus: 0, sourceFields: undefined, - supportedMLModels: undefined, + supportedMLModels: [], }; describe('MlInferenceLogic', () => { @@ -77,13 +87,25 @@ describe('MlInferenceLogic', () => { const { mount: mountCreateMlInferencePipelineApiLogic } = new LogicMounter( CreateMlInferencePipelineApiLogic ); + const { mount: mountAttachMlInferencePipelineApiLogic } = new LogicMounter( + AttachMlInferencePipelineApiLogic + ); + const { mount: mountFetchMlInferencePipelineProcessorsApiLogic } = new LogicMounter( + FetchMlInferencePipelineProcessorsApiLogic + ); + const { mount: mountFetchMlInferencePipelinesApiLogic } = new LogicMounter( + FetchMlInferencePipelinesApiLogic + ); beforeEach(() => { jest.clearAllMocks(); mountMappingApiLogic(); mountMLModelsApiLogic(); + mountFetchMlInferencePipelineProcessorsApiLogic(); + mountFetchMlInferencePipelinesApiLogic(); mountSimulateMlInterfacePipelineApiLogic(); mountCreateMlInferencePipelineApiLogic(); + mountAttachMlInferencePipelineApiLogic(); mount(); }); @@ -110,6 +132,70 @@ describe('MlInferenceLogic', () => { }); }); }); + describe('attachApiError', () => { + it('updates create errors', () => { + MLInferenceLogic.actions.attachApiError({ + body: { + error: '', + message: 'this is an error', + statusCode: 500, + }, + } as HttpResponse); + + expect(MLInferenceLogic.values.createErrors).toEqual(['this is an error']); + }); + }); + describe('createApiError', () => { + it('updates create errors', () => { + MLInferenceLogic.actions.createApiError({ + body: { + error: '', + message: 'this is an error', + statusCode: 500, + }, + } as HttpResponse); + + expect(MLInferenceLogic.values.createErrors).toEqual(['this is an error']); + }); + }); + describe('makeAttachPipelineRequest', () => { + it('clears existing errors', () => { + MLInferenceLogic.actions.attachApiError({ + body: { + error: '', + message: 'this is an error', + statusCode: 500, + }, + } as HttpResponse); + + expect(MLInferenceLogic.values.createErrors).not.toHaveLength(0); + MLInferenceLogic.actions.makeAttachPipelineRequest({ + indexName: 'test', + pipelineName: 'unit-test', + }); + expect(MLInferenceLogic.values.createErrors).toHaveLength(0); + }); + }); + describe('makeCreatePipelineRequest', () => { + it('clears existing errors', () => { + MLInferenceLogic.actions.createApiError({ + body: { + error: '', + message: 'this is an error', + statusCode: 500, + }, + } as HttpResponse); + + expect(MLInferenceLogic.values.createErrors).not.toHaveLength(0); + MLInferenceLogic.actions.makeCreatePipelineRequest({ + indexName: 'test', + pipelineName: 'unit-test', + modelId: 'test-model', + sourceField: 'body', + }); + expect(MLInferenceLogic.values.createErrors).toHaveLength(0); + }); + }); }); describe('selectors', () => { @@ -162,6 +248,220 @@ describe('MlInferenceLogic', () => { expect(MLInferenceLogic.values.simulatePipelineResult).toEqual(simulateResponse); }); }); + describe('existingInferencePipelines', () => { + beforeEach(() => { + MappingsApiLogic.actions.apiSuccess({ + mappings: { + properties: { + body: { + type: 'text', + }, + }, + }, + }); + }); + it('returns empty list when there is not existing pipelines available', () => { + expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([]); + }); + it('returns existing pipeline option', () => { + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': { + processors: [ + { + inference: { + field_map: { + body: 'text_field', + }, + model_id: 'test-model', + target_field: 'ml.inference.test-field', + }, + }, + ], + version: 1, + }, + }); + + expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([ + { + destinationField: 'test-field', + disabled: false, + pipelineName: 'unit-test', + modelType: '', + modelId: 'test-model', + sourceField: 'body', + }, + ]); + }); + it('returns disabled pipeline option if missing source field', () => { + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': { + processors: [ + { + inference: { + field_map: { + body_content: 'text_field', + }, + model_id: 'test-model', + target_field: 'ml.inference.test-field', + }, + }, + ], + version: 1, + }, + }); + + expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([ + { + destinationField: 'test-field', + disabled: true, + disabledReason: expect.any(String), + pipelineName: 'unit-test', + modelType: '', + modelId: 'test-model', + sourceField: 'body_content', + }, + ]); + }); + it('returns disabled pipeline option if model is redacted', () => { + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': { + processors: [ + { + inference: { + field_map: { + body: 'text_field', + }, + model_id: '', + target_field: 'ml.inference.test-field', + }, + }, + ], + version: 1, + }, + }); + + expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([ + { + destinationField: 'test-field', + disabled: true, + disabledReason: expect.any(String), + pipelineName: 'unit-test', + modelType: '', + modelId: '', + sourceField: 'body', + }, + ]); + }); + it('returns disabled pipeline option if pipeline already attached', () => { + FetchMlInferencePipelineProcessorsApiLogic.actions.apiSuccess([ + { + modelId: 'test-model', + modelState: TrainedModelState.Started, + pipelineName: 'unit-test', + pipelineReferences: ['test@ml-inference'], + types: ['ner', 'pytorch'], + }, + ]); + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': { + processors: [ + { + inference: { + field_map: { + body: 'text_field', + }, + model_id: 'test-model', + target_field: 'ml.inference.test-field', + }, + }, + ], + version: 1, + }, + }); + + expect(MLInferenceLogic.values.existingInferencePipelines).toEqual([ + { + destinationField: 'test-field', + disabled: true, + disabledReason: expect.any(String), + pipelineName: 'unit-test', + modelType: '', + modelId: 'test-model', + sourceField: 'body', + }, + ]); + }); + }); + describe('mlInferencePipeline', () => { + it('returns undefined when configuration is invalid', () => { + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + destinationField: '', + modelID: '', + pipelineName: 'unit-test', + sourceField: '', + }); + + expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); + }); + it('generates inference pipeline', () => { + MLModelsApiLogic.actions.apiSuccess([ + { + inference_config: { + text_classification: { + classification_labels: ['one', 'two'], + tokenization: { + bert: {}, + }, + }, + }, + input: { + field_names: ['text_field'], + }, + model_id: 'test-model', + model_type: 'pytorch', + tags: [], + version: '1.0.0', + }, + ]); + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + destinationField: '', + modelID: 'test-model', + pipelineName: 'unit-test', + sourceField: 'body', + }); + + expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); + }); + it('returns undefined when existing pipeline not yet selected', () => { + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + existingPipeline: true, + destinationField: '', + modelID: '', + pipelineName: '', + sourceField: '', + }); + expect(MLInferenceLogic.values.mlInferencePipeline).toBeUndefined(); + }); + it('return existing pipeline when selected', () => { + const existingPipeline = { + description: 'this is a test', + processors: [], + version: 1, + }; + FetchMlInferencePipelinesApiLogic.actions.apiSuccess({ + 'unit-test': existingPipeline, + }); + MLInferenceLogic.actions.setInferencePipelineConfiguration({ + existingPipeline: true, + destinationField: '', + modelID: '', + pipelineName: 'unit-test', + sourceField: '', + }); + expect(MLInferenceLogic.values.mlInferencePipeline).not.toBeUndefined(); + expect(MLInferenceLogic.values.mlInferencePipeline).toEqual(existingPipeline); + }); + }); }); describe('listeners', () => { diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts index f4a968da1c2a1..fcdad4f66d141 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/ml_inference_logic.ts @@ -15,6 +15,8 @@ import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_ import { formatPipelineName, generateMlInferencePipelineBody, + getMlModelTypesForModelConfig, + parseMlInferenceParametersFromPipeline, } from '../../../../../../../common/ml_inference_pipeline'; import { Status } from '../../../../../../../common/types/api'; import { MlInferencePipeline } from '../../../../../../../common/types/pipelines'; @@ -30,16 +32,30 @@ import { GetMappingsResponse, MappingsApiLogic, } from '../../../../api/mappings/mappings_logic'; -import { - CreateMlInferencePipelineApiLogic, - CreateMlInferencePipelineApiLogicArgs, - CreateMlInferencePipelineResponse, -} from '../../../../api/ml_models/create_ml_inference_pipeline'; import { GetMlModelsArgs, GetMlModelsResponse, MLModelsApiLogic, } from '../../../../api/ml_models/ml_models_logic'; +import { + AttachMlInferencePipelineApiLogic, + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse, +} from '../../../../api/pipelines/attach_ml_inference_pipeline'; +import { + CreateMlInferencePipelineApiLogic, + CreateMlInferencePipelineApiLogicArgs, + CreateMlInferencePipelineResponse, +} from '../../../../api/pipelines/create_ml_inference_pipeline'; +import { + FetchMlInferencePipelineProcessorsApiLogic, + FetchMlInferencePipelineProcessorsResponse, +} from '../../../../api/pipelines/fetch_ml_inference_pipeline_processors'; +import { + FetchMlInferencePipelinesApiLogic, + FetchMlInferencePipelinesArgs, + FetchMlInferencePipelinesResponse, +} from '../../../../api/pipelines/fetch_ml_inference_pipelines'; import { SimulateMlInterfacePipelineApiLogic, SimulateMlInterfacePipelineArgs, @@ -47,11 +63,20 @@ import { } from '../../../../api/pipelines/simulate_ml_inference_pipeline_processors'; import { isConnectorIndex } from '../../../../utils/indices'; -import { isSupportedMLModel, sortSourceFields } from '../../../shared/ml_inference/utils'; +import { + getMLType, + isSupportedMLModel, + sortSourceFields, +} from '../../../shared/ml_inference/utils'; import { AddInferencePipelineFormErrors, InferencePipelineConfiguration } from './types'; -import { validateInferencePipelineConfiguration } from './utils'; +import { + validateInferencePipelineConfiguration, + EXISTING_PIPELINE_DISABLED_MODEL_REDACTED, + EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD, + EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS, +} from './utils'; export const EMPTY_PIPELINE_CONFIGURATION: InferencePipelineConfiguration = { destinationField: '', @@ -69,7 +94,26 @@ export enum AddInferencePipelineSteps { const API_REQUEST_COMPLETE_STATUSES = [Status.SUCCESS, Status.ERROR]; const DEFAULT_CONNECTOR_FIELDS = ['body', 'title', 'id', 'type', 'url']; +export interface MLInferencePipelineOption { + destinationField: string; + disabled: boolean; + disabledReason?: string; + modelId: string; + modelType: string; + pipelineName: string; + sourceField: string; +} + interface MLInferenceProcessorsActions { + attachApiError: Actions< + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse + >['apiError']; + attachApiSuccess: Actions< + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse + >['apiSuccess']; + attachPipeline: () => void; createApiError: Actions< CreateMlInferencePipelineApiLogicArgs, CreateMlInferencePipelineResponse @@ -79,18 +123,29 @@ interface MLInferenceProcessorsActions { CreateMlInferencePipelineResponse >['apiSuccess']; createPipeline: () => void; + makeAttachPipelineRequest: Actions< + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse + >['makeRequest']; makeCreatePipelineRequest: Actions< CreateMlInferencePipelineApiLogicArgs, CreateMlInferencePipelineResponse >['makeRequest']; makeMLModelsRequest: Actions['makeRequest']; makeMappingRequest: Actions['makeRequest']; + makeMlInferencePipelinesRequest: Actions< + FetchMlInferencePipelinesArgs, + FetchMlInferencePipelinesResponse + >['makeRequest']; makeSimulatePipelineRequest: Actions< SimulateMlInterfacePipelineArgs, SimulateMlInterfacePipelineResponse >['makeRequest']; mappingsApiError: Actions['apiError']; mlModelsApiError: Actions['apiError']; + selectExistingPipeline: (pipelineName: string) => { + pipelineName: string; + }; setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => { step: AddInferencePipelineSteps; }; @@ -120,21 +175,24 @@ export interface AddInferencePipelineModal { step: AddInferencePipelineSteps; } -interface MLInferenceProcessorsValues { +export interface MLInferenceProcessorsValues { addInferencePipelineModal: AddInferencePipelineModal; createErrors: string[]; + existingInferencePipelines: MLInferencePipelineOption[]; formErrors: AddInferencePipelineFormErrors; - index: FetchIndexApiResponse; + index: FetchIndexApiResponse | undefined; isLoading: boolean; isPipelineDataValid: boolean; mappingData: typeof MappingsApiLogic.values.data; mappingStatus: Status; - mlInferencePipeline?: MlInferencePipeline; - mlModelsData: TrainedModelConfigResponse[]; + mlInferencePipeline: MlInferencePipeline | undefined; + mlInferencePipelineProcessors: FetchMlInferencePipelineProcessorsResponse | undefined; + mlInferencePipelinesData: FetchMlInferencePipelinesResponse | undefined; + mlModelsData: TrainedModelConfigResponse[] | undefined; mlModelsStatus: Status; simulatePipelineData: typeof SimulateMlInterfacePipelineApiLogic.values.data; simulatePipelineErrors: string[]; - simulatePipelineResult: IngestSimulateResponse; + simulatePipelineResult: IngestSimulateResponse | undefined; simulatePipelineStatus: Status; sourceFields: string[] | undefined; supportedMLModels: TrainedModelConfigResponse[]; @@ -144,8 +202,10 @@ export const MLInferenceLogic = kea< MakeLogicType >({ actions: { + attachPipeline: true, clearFormErrors: true, createPipeline: true, + selectExistingPipeline: (pipelineName: string) => ({ pipelineName }), setAddInferencePipelineStep: (step: AddInferencePipelineSteps) => ({ step }), setFormErrors: (inputErrors: AddInferencePipelineFormErrors) => ({ inputErrors }), setIndexName: (indexName: string) => ({ indexName }), @@ -160,6 +220,8 @@ export const MLInferenceLogic = kea< }, connect: { actions: [ + FetchMlInferencePipelinesApiLogic, + ['makeRequest as makeMlInferencePipelinesRequest'], MappingsApiLogic, ['makeRequest as makeMappingRequest', 'apiError as mappingsApiError'], MLModelsApiLogic, @@ -176,20 +238,43 @@ export const MLInferenceLogic = kea< 'apiSuccess as createApiSuccess', 'makeRequest as makeCreatePipelineRequest', ], + AttachMlInferencePipelineApiLogic, + [ + 'apiError as attachApiError', + 'apiSuccess as attachApiSuccess', + 'makeRequest as makeAttachPipelineRequest', + ], ], values: [ FetchIndexApiLogic, ['data as index'], + FetchMlInferencePipelinesApiLogic, + ['data as mlInferencePipelinesData'], MappingsApiLogic, ['data as mappingData', 'status as mappingStatus'], MLModelsApiLogic, ['data as mlModelsData', 'status as mlModelsStatus'], SimulateMlInterfacePipelineApiLogic, ['data as simulatePipelineData', 'status as simulatePipelineStatus'], + FetchMlInferencePipelineProcessorsApiLogic, + ['data as mlInferencePipelineProcessors'], ], }, events: {}, listeners: ({ values, actions }) => ({ + attachPipeline: () => { + const { + addInferencePipelineModal: { + configuration: { pipelineName }, + indexName, + }, + } = values; + + actions.makeAttachPipelineRequest({ + indexName, + pipelineName, + }); + }, createPipeline: () => { const { addInferencePipelineModal: { configuration, indexName }, @@ -206,7 +291,21 @@ export const MLInferenceLogic = kea< sourceField: configuration.sourceField, }); }, + selectExistingPipeline: ({ pipelineName }) => { + const pipeline = values.mlInferencePipelinesData?.[pipelineName]; + if (!pipeline) return; + const params = parseMlInferenceParametersFromPipeline(pipelineName, pipeline); + if (params === null) return; + actions.setInferencePipelineConfiguration({ + destinationField: params.destination_field ?? '', + existingPipeline: true, + modelID: params.model_id, + pipelineName, + sourceField: params.source_field, + }); + }, setIndexName: ({ indexName }) => { + actions.makeMlInferencePipelinesRequest(undefined); actions.makeMLModelsRequest(undefined); actions.makeMappingRequest({ indexName }); }, @@ -264,7 +363,9 @@ export const MLInferenceLogic = kea< createErrors: [ [], { + attachApiError: (_, error) => getErrorsFromHttpResponse(error), createApiError: (_, error) => getErrorsFromHttpResponse(error), + makeAttachPipelineRequest: () => [], makeCreatePipelineRequest: () => [], }, ], @@ -297,12 +398,24 @@ export const MLInferenceLogic = kea< selectors.isPipelineDataValid, selectors.addInferencePipelineModal, selectors.mlModelsData, + selectors.mlInferencePipelinesData, ], ( - isPipelineDataValid: boolean, - { configuration }: AddInferencePipelineModal, - models: MLInferenceProcessorsValues['mlModelsData'] + isPipelineDataValid: MLInferenceProcessorsValues['isPipelineDataValid'], + { configuration }: MLInferenceProcessorsValues['addInferencePipelineModal'], + models: MLInferenceProcessorsValues['mlModelsData'], + mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData'] ) => { + if (configuration.existingPipeline) { + if (configuration.pipelineName.length === 0) { + return undefined; + } + const pipeline = mlInferencePipelinesData?.[configuration.pipelineName]; + if (!pipeline) { + return undefined; + } + return pipeline as MlInferencePipeline; + } if (!isPipelineDataValid) return undefined; const model = models?.find((mlModel) => mlModel.model_id === configuration.modelID); if (!model) return undefined; @@ -350,7 +463,69 @@ export const MLInferenceLogic = kea< supportedMLModels: [ () => [selectors.mlModelsData], (mlModelsData: TrainedModelConfigResponse[] | undefined) => { - return mlModelsData?.filter(isSupportedMLModel); + return mlModelsData?.filter(isSupportedMLModel) ?? []; + }, + ], + existingInferencePipelines: [ + () => [ + selectors.mlInferencePipelinesData, + selectors.sourceFields, + selectors.supportedMLModels, + selectors.mlInferencePipelineProcessors, + ], + ( + mlInferencePipelinesData: MLInferenceProcessorsValues['mlInferencePipelinesData'], + sourceFields: MLInferenceProcessorsValues['sourceFields'], + supportedMLModels: MLInferenceProcessorsValues['supportedMLModels'], + mlInferencePipelineProcessors: MLInferenceProcessorsValues['mlInferencePipelineProcessors'] + ) => { + if (!mlInferencePipelinesData) { + return []; + } + const indexProcessorNames = + mlInferencePipelineProcessors?.map((processor) => processor.pipelineName) ?? []; + + const existingPipelines: MLInferencePipelineOption[] = Object.entries( + mlInferencePipelinesData + ) + .map(([pipelineName, pipeline]): MLInferencePipelineOption | undefined => { + if (!pipeline) return undefined; + const pipelineParams = parseMlInferenceParametersFromPipeline(pipelineName, pipeline); + if (!pipelineParams) return undefined; + const { + destination_field: destinationField, + model_id: modelId, + source_field: sourceField, + } = pipelineParams; + + let disabled: boolean = false; + let disabledReason: string | undefined; + if (!(sourceFields?.includes(sourceField) ?? false)) { + disabled = true; + disabledReason = EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD; + } else if (indexProcessorNames.includes(pipelineName)) { + disabled = true; + disabledReason = EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS; + } else if (pipelineParams.model_id.length === 0) { + disabled = true; + disabledReason = EXISTING_PIPELINE_DISABLED_MODEL_REDACTED; + } + const mlModel = supportedMLModels.find((model) => model.model_id === modelId); + const modelType = mlModel ? getMLType(getMlModelTypesForModelConfig(mlModel)) : ''; + + return { + destinationField: destinationField ?? '', + disabled, + disabledReason, + modelId, + modelType, + pipelineName, + sourceField, + }; + }) + .filter((p): p is MLInferencePipelineOption => p !== undefined); + + return existingPipelines; }, ], }), diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/pipeline_select_option.tsx b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/pipeline_select_option.tsx new file mode 100644 index 0000000000000..f782c827a9728 --- /dev/null +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/pipeline_select_option.tsx @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +import React from 'react'; + +import { EuiBadge, EuiFlexGroup, EuiFlexItem, EuiIcon, EuiTextColor, EuiTitle } from '@elastic/eui'; +import { i18n } from '@kbn/i18n'; + +import { MLInferencePipelineOption } from './ml_inference_logic'; +import { EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD } from './utils'; + +export interface PipelineSelectOptionProps { + pipeline: MLInferencePipelineOption; +} + +const REDACTED_MODE_ID_DISPLAY = i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.redactedModel', + { + defaultMessage: 'Trained model not available in this space', + } +); + +export const PipelineSelectOption: React.FC = ({ pipeline }) => { + const modelIdDisplay = pipeline.modelId.length > 0 ? pipeline.modelId : REDACTED_MODE_ID_DISPLAY; + return ( + + {pipeline.disabled && ( + + + + + + + + {pipeline.disabledReason ?? EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD} + + + + + )} + + +

{pipeline.pipelineName}

+
+
+ + + + {pipeline.disabled ? ( + modelIdDisplay + ) : ( + {modelIdDisplay} + )} + + {pipeline.modelType.length > 0 && ( + + + {pipeline.modelType} + + + )} + + + + + + + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.sourceField', + { defaultMessage: 'Source field' } + )} + + + {pipeline.sourceField} + + + + + + + {i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.destinationField', + { defaultMessage: 'Destination field' } + )} + + + {pipeline.destinationField} + + +
+ ); +}; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts index 29ad5e9193fdb..9ad288c4b84f5 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/types.ts @@ -7,6 +7,7 @@ export interface InferencePipelineConfiguration { destinationField: string; + existingPipeline?: boolean; modelID: string; pipelineName: string; sourceField: string; diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts index 8db23f5deb7d6..8ad94e5f92da4 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/ml_inference/utils.ts @@ -31,6 +31,12 @@ export const validateInferencePipelineConfiguration = ( config: InferencePipelineConfiguration ): AddInferencePipelineFormErrors => { const errors: AddInferencePipelineFormErrors = {}; + if (config.existingPipeline === true) { + if (config.pipelineName.length === 0) { + errors.pipelineName = FIELD_REQUIRED_ERROR; + } + return errors; + } if (config.pipelineName.trim().length === 0) { errors.pipelineName = FIELD_REQUIRED_ERROR; } else if (!isValidPipelineName(config.pipelineName)) { @@ -45,3 +51,27 @@ export const validateInferencePipelineConfiguration = ( return errors; }; + +export const EXISTING_PIPELINE_DISABLED_MISSING_SOURCE_FIELD = i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledSourceFieldDescription', + { + defaultMessage: + 'This pipeline cannot be selected because the source field does not exist on this index.', + } +); + +export const EXISTING_PIPELINE_DISABLED_PIPELINE_EXISTS = i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledPipelineExistsDescription', + { + defaultMessage: 'This pipeline cannot be selected because it is already attached.', + } +); + +// TODO: removed when we support attaching pipelines with unavailable models +export const EXISTING_PIPELINE_DISABLED_MODEL_REDACTED = i18n.translate( + 'xpack.enterpriseSearch.content.indices.pipelines.addInferencePipelineModal.steps.configure.existingPipeline.disabledModelRedactedDescription', + { + defaultMessage: + 'This pipeline cannot be selected because it uses a trained model not available in this Kibana space.', + } +); diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/pipelines_logic.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/pipelines_logic.ts index dca18863cde02..f4c9aad591c72 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/pipelines_logic.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/search_index/pipelines/pipelines_logic.ts @@ -47,12 +47,21 @@ import { FetchIndexApiParams, FetchIndexApiResponse, } from '../../../api/index/fetch_index_api_logic'; -import { CreateMlInferencePipelineApiLogic } from '../../../api/ml_models/create_ml_inference_pipeline'; import { DeleteMlInferencePipelineApiLogic, DeleteMlInferencePipelineApiLogicArgs, DeleteMlInferencePipelineResponse, } from '../../../api/ml_models/delete_ml_inference_pipeline'; +import { + AttachMlInferencePipelineApiLogic, + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse, +} from '../../../api/pipelines/attach_ml_inference_pipeline'; +import { + CreateMlInferencePipelineApiLogic, + CreateMlInferencePipelineApiLogicArgs, + CreateMlInferencePipelineResponse, +} from '../../../api/pipelines/create_ml_inference_pipeline'; import { FetchMlInferencePipelineProcessorsApiLogic } from '../../../api/pipelines/fetch_ml_inference_pipeline_processors'; import { isApiIndex, isConnectorIndex, isCrawlerIndex } from '../../../utils/indices'; @@ -60,6 +69,10 @@ type PipelinesActions = Pick< Actions, 'apiError' | 'apiSuccess' | 'makeRequest' > & { + attachMlInferencePipelineSuccess: Actions< + AttachMlInferencePipelineApiLogicArgs, + AttachMlInferencePipelineResponse + >['apiSuccess']; closeAddMlInferencePipelineModal: () => void; closeModal: () => void; createCustomPipeline: Actions< @@ -74,6 +87,10 @@ type PipelinesActions = Pick< CreateCustomPipelineApiLogicArgs, CreateCustomPipelineApiLogicResponse >['apiSuccess']; + createMlInferencePipelineSuccess: Actions< + CreateMlInferencePipelineApiLogicArgs, + CreateMlInferencePipelineResponse + >['apiSuccess']; deleteMlPipeline: Actions< DeleteMlInferencePipelineApiLogicArgs, DeleteMlInferencePipelineResponse @@ -153,6 +170,8 @@ export const PipelinesLogic = kea { + // Re-fetch processors to ensure we display newly added ml processor + actions.fetchMlInferenceProcessors({ indexName: values.index.name }); + // Needed to ensure correct JSON is available in the JSON configurations tab + actions.fetchCustomPipeline({ indexName: values.index.name }); + }, closeModal: () => actions.setPipelineState( isConnectorIndex(values.index) || isCrawlerIndex(values.index) @@ -287,6 +312,7 @@ export const PipelinesLogic = kea false, closeAddMlInferencePipelineModal: () => false, createMlInferencePipelineSuccess: () => false, openAddMlInferencePipelineModal: () => true, diff --git a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts index 0b2955cb7f30e..f24fe059cc5d0 100644 --- a/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts +++ b/x-pack/plugins/enterprise_search/public/applications/enterprise_search_content/components/shared/ml_inference/utils.ts @@ -8,14 +8,9 @@ import { i18n } from '@kbn/i18n'; import { TrainedModelConfigResponse } from '@kbn/ml-plugin/common/types/trained_models'; -export const NLP_CONFIG_KEYS = [ - 'fill_mask', - 'ner', - 'text_classification', - 'text_embedding', - 'question_answering', - 'zero_shot_classification', -]; +import { SUPPORTED_PYTORCH_TASKS } from '../../../../../../common/ml_inference_pipeline'; + +export const NLP_CONFIG_KEYS: string[] = Object.values(SUPPORTED_PYTORCH_TASKS); export const RECOMMENDED_FIELDS = ['body', 'body_content', 'title']; export const NLP_DISPLAY_TITLES: Record = {