diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/inference_base.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/inference_base.ts index 6f9305ae92f6d..82fbab3042a94 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/inference_base.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/inference_base.ts @@ -16,6 +16,7 @@ import { ES_FIELD_TYPES } from '@kbn/field-types'; import type { MLHttpFetchError } from '@kbn/ml-error-utils'; import type { trainedModelsApiProvider } from '../../../services/ml_api_service/trained_models'; import { getInferenceInfoComponent } from './inference_info'; +import type { ITelemetryClient } from '../../../services/telemetry/types'; export type InferenceType = | SupportedPytorchTasksType @@ -84,7 +85,8 @@ export abstract class InferenceBase { protected readonly trainedModelsApi: ReturnType, protected readonly model: estypes.MlTrainedModelConfig, protected readonly inputType: INPUT_TYPE, - protected readonly deploymentId: string + protected readonly deploymentId: string, + private readonly telemetryClient: ITelemetryClient ) { this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD; this.inputField$.next(this.modelInputField); @@ -317,9 +319,14 @@ export abstract class InferenceBase { this.inferenceResult$.next([processedResponse]); this.setFinished(); + this.trackModelTested('success'); + return [processedResponse]; } catch (error) { this.setFinishedWithErrors(error); + + this.trackModelTested('failure'); + throw error; } } @@ -336,9 +343,15 @@ export abstract class InferenceBase { const processedResponse = docs.map((d) => processResponse(this.getDocFromResponse(d))); this.inferenceResult$.next(processedResponse); this.setFinished(); + + this.trackModelTested('success'); + return processedResponse; } catch (error) { this.setFinishedWithErrors(error); + + this.trackModelTested('failure'); + throw error; } } @@ -392,4 +405,13 @@ export abstract class InferenceBase { } return doc; } + + private trackModelTested(result: 'success' | 'failure') { + this.telemetryClient.trackTrainedModelsModelTested({ + model_id: this.model.model_id, + model_type: this.model.model_type, + task_type: this.inferenceType, + result, + }); + } } diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/ner/ner_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/ner/ner_inference.ts index d49db2d03fe01..c3bb3d7f0c250 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/ner/ner_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/ner/ner_inference.ts @@ -13,6 +13,7 @@ import { InferenceBase, INPUT_TYPE } from '../inference_base'; import type { InferResponse } from '../inference_base'; import { getGeneralInputComponent } from '../text_input'; import { getNerOutputComponent } from './ner_output'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export type FormattedNerResponse = Array<{ value: string; @@ -37,9 +38,10 @@ export class NerInference extends InferenceBase { trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize(); } diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts index 2bb5593a45a13..884ce85e9259c 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/question_answering/question_answering_inference.ts @@ -15,6 +15,7 @@ import type { InferResponse, INPUT_TYPE } from '../inference_base'; import { getQuestionAnsweringInput } from './question_answering_input'; import { getQuestionAnsweringOutputComponent } from './question_answering_output'; import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export interface RawQuestionAnsweringResponse { inference_results: Array<{ @@ -63,9 +64,10 @@ export class QuestionAnsweringInference extends InferenceBase, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize( [this.questionText$.pipe(map((questionText) => questionText !== ''))], diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts index 156062a77389d..efc1de14af48b 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/fill_mask_inference.ts @@ -15,6 +15,7 @@ import { processResponse, processInferenceResult } from './common'; import { getGeneralInputComponent } from '../text_input'; import { getFillMaskOutputComponent } from './fill_mask_output'; import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; const DEFAULT_MASK_TOKEN = '[MASK]'; @@ -36,9 +37,10 @@ export class FillMaskInference extends InferenceBase trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); const maskToken = model.inference_config?.[this.inferenceType]?.mask_token; if (maskToken) { this.maskToken = maskToken; diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts index 4f0548088c836..2e04be860dc98 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/lang_ident_inference.ts @@ -14,6 +14,7 @@ import { getGeneralInputComponent } from '../text_input'; import { getLangIdentOutputComponent } from './lang_ident_output'; import type { TextClassificationResponse, RawTextClassificationResponse } from './common'; import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export class LangIdentInference extends InferenceBase { protected inferenceType: InferenceType = 'classification'; @@ -32,9 +33,10 @@ export class LangIdentInference extends InferenceBase, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize(); } diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts index 5c2641ae058b8..44f7d75aefe14 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/text_classification_inference.ts @@ -14,6 +14,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from ' import { getGeneralInputComponent } from '../text_input'; import { getTextClassificationOutputComponent } from './text_classification_output'; import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export class TextClassificationInference extends InferenceBase { protected inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION; @@ -31,9 +32,10 @@ export class TextClassificationInference extends InferenceBase, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize(); } diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts index a06af98081e80..9aa012d576edb 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_classification/zero_shot_classification_inference.ts @@ -18,6 +18,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from ' import { getZeroShotClassificationInput } from './zero_shot_classification_input'; import { getTextClassificationOutputComponent } from './text_classification_output'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export class ZeroShotClassificationInference extends InferenceBase { protected inferenceType = SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION; @@ -39,9 +40,10 @@ export class ZeroShotClassificationInference extends InferenceBase, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize( [this.labelsText$.pipe(map((labelsText) => labelsText !== ''))], diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts index 9a1081da7272b..f958749826252 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_embedding/text_embedding_inference.ts @@ -13,6 +13,7 @@ import type { InferResponse } from '../inference_base'; import { getGeneralInputComponent } from '../text_input'; import { getTextEmbeddingOutputComponent } from './text_embedding_output'; import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export interface RawTextEmbeddingResponse { inference_results: Array<{ predicted_value: number[] }>; @@ -43,9 +44,10 @@ export class TextEmbeddingInference extends InferenceBase trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize(); } diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_expansion/text_expansion_inference.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_expansion/text_expansion_inference.ts index d1e322e5d5d32..0d56178bcb4f8 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_expansion/text_expansion_inference.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/models/text_expansion/text_expansion_inference.ts @@ -15,6 +15,7 @@ import type { INPUT_TYPE } from '../inference_base'; import { InferenceBase, type InferResponse } from '../inference_base'; import { getTextExpansionOutputComponent } from './text_expansion_output'; import { getTextExpansionInput } from './text_expansion_input'; +import type { ITelemetryClient } from '../../../../services/telemetry/types'; export interface TextExpansionPair { token: string; @@ -53,9 +54,10 @@ export class TextExpansionInference extends InferenceBase trainedModelsApi: ReturnType, model: estypes.MlTrainedModelConfig, inputType: INPUT_TYPE, - deploymentId: string + deploymentId: string, + telemetryClient: ITelemetryClient ) { - super(trainedModelsApi, model, inputType, deploymentId); + super(trainedModelsApi, model, inputType, deploymentId, telemetryClient); this.initialize( [this.queryText$.pipe(map((questionText) => questionText !== ''))], diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/selected_model.tsx b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/selected_model.tsx index 598bd206608c6..b5c829b22d020 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/selected_model.tsx +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/test_models/selected_model.tsx @@ -35,6 +35,7 @@ import { isMlIngestInferenceProcessor, isMlInferencePipelineInferenceConfig, } from '../create_pipeline_for_model/get_inference_properties_from_pipeline_config'; +import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context'; interface Props { model: estypes.MlTrainedModelConfig; @@ -54,6 +55,7 @@ export const SelectedModel: FC = ({ setCurrentContext, }) => { const { trainedModels } = useMlApi(); + const { telemetryClient } = useMlTelemetryClient(); const inferrer = useMemo(() => { const taskType = Object.keys(model.inference_config ?? {})[0]; @@ -65,14 +67,21 @@ export const SelectedModel: FC = ({ if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) { switch (taskType) { case SUPPORTED_PYTORCH_TASKS.NER: - tempInferrer = new NerInference(trainedModels, model, inputType, deploymentId); + tempInferrer = new NerInference( + trainedModels, + model, + inputType, + deploymentId, + telemetryClient + ); break; case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION: tempInferrer = new TextClassificationInference( trainedModels, model, inputType, - deploymentId + deploymentId, + telemetryClient ); break; case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION: @@ -80,7 +89,8 @@ export const SelectedModel: FC = ({ trainedModels, model, inputType, - deploymentId + deploymentId, + telemetryClient ); if (pipelineConfigValues) { const { labels, multi_label: multiLabel } = pipelineConfigValues; @@ -91,30 +101,55 @@ export const SelectedModel: FC = ({ } break; case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING: - tempInferrer = new TextEmbeddingInference(trainedModels, model, inputType, deploymentId); + tempInferrer = new TextEmbeddingInference( + trainedModels, + model, + inputType, + deploymentId, + telemetryClient + ); break; case SUPPORTED_PYTORCH_TASKS.FILL_MASK: - tempInferrer = new FillMaskInference(trainedModels, model, inputType, deploymentId); + tempInferrer = new FillMaskInference( + trainedModels, + model, + inputType, + deploymentId, + telemetryClient + ); break; case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING: tempInferrer = new QuestionAnsweringInference( trainedModels, model, inputType, - deploymentId + deploymentId, + telemetryClient ); if (pipelineConfigValues?.question) { tempInferrer.setQuestionText(pipelineConfigValues.question); } break; case SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION: - tempInferrer = new TextExpansionInference(trainedModels, model, inputType, deploymentId); + tempInferrer = new TextExpansionInference( + trainedModels, + model, + inputType, + deploymentId, + telemetryClient + ); break; default: break; } } else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) { - tempInferrer = new LangIdentInference(trainedModels, model, inputType, deploymentId); + tempInferrer = new LangIdentInference( + trainedModels, + model, + inputType, + deploymentId, + telemetryClient + ); } if (tempInferrer) { if (pipelineConfigValues) { diff --git a/x-pack/platform/plugins/shared/ml/public/application/model_management/trained_models_service.test.ts b/x-pack/platform/plugins/shared/ml/public/application/model_management/trained_models_service.test.ts index 3112a050a3041..1549087f9579e 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/model_management/trained_models_service.test.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/model_management/trained_models_service.test.ts @@ -88,7 +88,7 @@ describe('TrainedModelsService', () => { mockTelemetryService = { trackTrainedModelsDeploymentCreated: jest.fn(), - }; + } as unknown as jest.Mocked; mockTrainedModelsApiService = { getTrainedModelsList: jest.fn(), diff --git a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/events.ts b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/events.ts index 45a07b12f22dc..820ee6bc2e285 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/events.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/events.ts @@ -6,6 +6,7 @@ */ import type { SchemaObject } from '@elastic/ebt'; +import type { TrainedModelsModelTestedEbtProps } from './types'; import { TrainedModelsTelemetryEventTypes, type TrainedModelsDeploymentEbtProps, @@ -66,11 +67,47 @@ const trainedModelsDeploymentSchema: SchemaObject['properties'] = + { + model_id: { + type: 'keyword', + _meta: { + description: 'The ID of the trained model', + }, + }, + model_type: { + type: 'keyword', + _meta: { + description: 'The type of the trained model', + optional: true, + }, + }, + task_type: { + type: 'keyword', + _meta: { + description: 'The type of the task', + optional: true, + }, + }, + result: { + type: 'keyword', + _meta: { + description: 'The result of the task', + }, + }, + }; + const trainedModelsDeploymentCreatedEventType: TrainedModelsTelemetryEvent = { eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED, schema: trainedModelsDeploymentSchema, }; +const trainedModelsModelTestedEventType: TrainedModelsTelemetryEvent = { + eventType: TrainedModelsTelemetryEventTypes.MODEL_TESTED, + schema: trainedModelsModelTestedSchema, +}; + export const trainedModelsEbtEvents = { trainedModelsDeploymentCreatedEventType, + trainedModelsModelTestedEventType, }; diff --git a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_client.ts b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_client.ts index 2b5cf03c33ffc..459dee94fac01 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_client.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_client.ts @@ -6,7 +6,11 @@ */ import type { AnalyticsServiceSetup } from '@kbn/core-analytics-browser'; -import type { ITelemetryClient, TrainedModelsDeploymentEbtProps } from './types'; +import type { + ITelemetryClient, + TrainedModelsDeploymentEbtProps, + TrainedModelsModelTestedEbtProps, +} from './types'; import { TrainedModelsTelemetryEventTypes } from './types'; export class TelemetryClient implements ITelemetryClient { @@ -15,4 +19,8 @@ export class TelemetryClient implements ITelemetryClient { public trackTrainedModelsDeploymentCreated = (eventProps: TrainedModelsDeploymentEbtProps) => { this.analytics.reportEvent(TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED, eventProps); }; + + public trackTrainedModelsModelTested = (eventProps: TrainedModelsModelTestedEbtProps) => { + this.analytics.reportEvent(TrainedModelsTelemetryEventTypes.MODEL_TESTED, eventProps); + }; } diff --git a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_service.ts b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_service.ts index 6afa8b793e17e..aa06320db68ad 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_service.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/telemetry_service.ts @@ -23,6 +23,7 @@ export class TelemetryService { this.analytics = analytics; analytics.registerEventType(trainedModelsEbtEvents.trainedModelsDeploymentCreatedEventType); + analytics.registerEventType(trainedModelsEbtEvents.trainedModelsModelTestedEventType); } public start(): ITelemetryClient { diff --git a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/types.ts b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/types.ts index 4b9de880c73bc..6ec8969dc82d8 100644 --- a/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/types.ts +++ b/x-pack/platform/plugins/shared/ml/public/application/services/telemetry/types.ts @@ -6,6 +6,7 @@ */ import type { RootSchema } from '@kbn/core/public'; +import type { TrainedModelType } from '@kbn/ml-trained-models-utils'; export interface TrainedModelsDeploymentEbtProps { model_id: string; @@ -18,15 +19,29 @@ export interface TrainedModelsDeploymentEbtProps { vcpu_usage: 'low' | 'medium' | 'high'; } +export interface TrainedModelsModelTestedEbtProps { + model_id: string; + model_type?: TrainedModelType; + task_type?: string; + result: 'success' | 'failure'; +} + export enum TrainedModelsTelemetryEventTypes { DEPLOYMENT_CREATED = 'Trained Models Deployment Created', + MODEL_TESTED = 'Trained Model Tested', } -export interface TrainedModelsTelemetryEvent { - eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED; - schema: RootSchema; -} +export type TrainedModelsTelemetryEvent = + | { + eventType: TrainedModelsTelemetryEventTypes.DEPLOYMENT_CREATED; + schema: RootSchema; + } + | { + eventType: TrainedModelsTelemetryEventTypes.MODEL_TESTED; + schema: RootSchema; + }; export interface ITelemetryClient { trackTrainedModelsDeploymentCreated: (eventProps: TrainedModelsDeploymentEbtProps) => void; + trackTrainedModelsModelTested: (eventProps: TrainedModelsModelTestedEbtProps) => void; }