@@ -16,6 +16,7 @@ import { ES_FIELD_TYPES } from '@kbn/field-types';
import type { MLHttpFetchError } from '@kbn/ml-error-utils';
import type { trainedModelsApiProvider } from '../../../services/ml_api_service/trained_models';
import { getInferenceInfoComponent } from './inference_info';
import type { ITelemetryClient } from '../../../services/telemetry/types';

export type InferenceType =
| SupportedPytorchTasksType
@@ -84,7 +85,8 @@ export abstract class InferenceBase<TInferResponse> {
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
protected readonly inputType: INPUT_TYPE,
protected readonly deploymentId: string
protected readonly deploymentId: string,
private readonly telemetryClient: ITelemetryClient
) {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField$.next(this.modelInputField);
@@ -317,9 +319,14 @@ export abstract class InferenceBase<TInferResponse> {
this.inferenceResult$.next([processedResponse]);
this.setFinished();

this.trackModelTested('success');

return [processedResponse];
} catch (error) {
this.setFinishedWithErrors(error);

this.trackModelTested('failure');

throw error;
}
}
@@ -336,9 +343,15 @@ export abstract class InferenceBase<TInferResponse> {
const processedResponse = docs.map((d) => processResponse(this.getDocFromResponse(d)));
this.inferenceResult$.next(processedResponse);
this.setFinished();

this.trackModelTested('success');

return processedResponse;
} catch (error) {
this.setFinishedWithErrors(error);

this.trackModelTested('failure');

throw error;
}
}
@@ -391,4 +404,13 @@ export abstract class InferenceBase<TInferResponse> {
}
return doc;
}

private trackModelTested(result: 'success' | 'failure') {
this.telemetryClient.trackTrainedModelsModelTested({
model_id: this.model.model_id,
model_type: this.model.model_type,
task_type: this.inferenceType,
result,
});
}
}
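
For reference, `ITelemetryClient` only appears above as a type import. A minimal sketch of the slice this diff relies on, with field types inferred from the `trackModelTested` call site rather than from the real definitions in `../../../services/telemetry/types`:

// Sketch only: the actual interface is defined by the ML telemetry service
// and may declare more events and stricter field types.
interface TrainedModelsModelTestedEventData {
  model_id: string;
  // model_type is optional on estypes.MlTrainedModelConfig, so it is optional here too
  model_type?: string;
  // the InferenceType of the concrete tester subclass, e.g. 'ner' or 'fill_mask'
  task_type: string;
  result: 'success' | 'failure';
}

interface ITelemetryClient {
  trackTrainedModelsModelTested(data: TrainedModelsModelTestedEventData): void;
  // Assumed from the test file at the end of this diff; payload shape not shown there.
  trackTrainedModelsDeploymentCreated(data: Record<string, unknown>): void;
}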
@@ -13,6 +13,7 @@ import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getNerOutputComponent } from './ner_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export type FormattedNerResponse = Array<{
value: string;
@@ -37,9 +38,10 @@ export class NerInference extends InferenceBase<NerResponse> {
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
@@ -15,6 +15,7 @@ import type { InferResponse, INPUT_TYPE } from '../inference_base';
import { getQuestionAnsweringInput } from './question_answering_input';
import { getQuestionAnsweringOutputComponent } from './question_answering_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface RawQuestionAnsweringResponse {
inference_results: Array<{
@@ -63,9 +64,10 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.questionText$.pipe(map((questionText) => questionText !== ''))],
@@ -15,6 +15,7 @@ import { processResponse, processInferenceResult } from './common';
import { getGeneralInputComponent } from '../text_input';
import { getFillMaskOutputComponent } from './fill_mask_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

const DEFAULT_MASK_TOKEN = '[MASK]';

@@ -36,9 +37,10 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
const maskToken = model.inference_config?.[this.inferenceType]?.mask_token;
if (maskToken) {
this.maskToken = maskToken;
@@ -14,6 +14,7 @@ import { getGeneralInputComponent } from '../text_input';
import { getLangIdentOutputComponent } from './lang_ident_output';
import type { TextClassificationResponse, RawTextClassificationResponse } from './common';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class LangIdentInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType: InferenceType = 'classification';
@@ -32,9 +33,10 @@ export class LangIdentInference extends InferenceBase<TextClassificationResponse
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
@@ -14,6 +14,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '
import { getGeneralInputComponent } from '../text_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class TextClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
@@ -31,9 +32,10 @@ export class TextClassificationInference extends InferenceBase<TextClassificatio
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
@@ -18,6 +18,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '

import { getZeroShotClassificationInput } from './zero_shot_classification_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class ZeroShotClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION;
@@ -39,9 +40,10 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.labelsText$.pipe(map((labelsText) => labelsText !== ''))],
@@ -13,6 +13,7 @@ import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getTextEmbeddingOutputComponent } from './text_embedding_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface RawTextEmbeddingResponse {
inference_results: Array<{ predicted_value: number[] }>;
@@ -43,9 +44,10 @@ export class TextEmbeddingInference extends InferenceBase<TextEmbeddingResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
@@ -15,6 +15,7 @@ import type { INPUT_TYPE } from '../inference_base';
import { InferenceBase, type InferResponse } from '../inference_base';
import { getTextExpansionOutputComponent } from './text_expansion_output';
import { getTextExpansionInput } from './text_expansion_input';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface TextExpansionPair {
token: string;
@@ -53,9 +54,10 @@ export class TextExpansionInference extends InferenceBase<TextExpansionResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.queryText$.pipe(map((questionText) => questionText !== ''))],
@@ -35,6 +35,7 @@ import {
isMlIngestInferenceProcessor,
isMlInferencePipelineInferenceConfig,
} from '../create_pipeline_for_model/get_inference_properties_from_pipeline_config';
import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context';

interface Props {
model: estypes.MlTrainedModelConfig;
@@ -54,6 +55,7 @@ export const SelectedModel: FC<Props> = ({
setCurrentContext,
}) => {
const { trainedModels } = useMlApi();
const { telemetryClient } = useMlTelemetryClient();

const inferrer = useMemo<InferrerType | undefined>(() => {
const taskType = Object.keys(model.inference_config ?? {})[0];
@@ -65,22 +67,30 @@ export const SelectedModel: FC<Props> = ({
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
tempInferrer = new NerInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new NerInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
tempInferrer = new TextClassificationInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
tempInferrer = new ZeroShotClassificationInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues) {
const { labels, multi_label: multiLabel } = pipelineConfigValues;
@@ -91,30 +101,55 @@ export const SelectedModel: FC<Props> = ({
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
tempInferrer = new TextEmbeddingInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextEmbeddingInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
tempInferrer = new FillMaskInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new FillMaskInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
tempInferrer = new QuestionAnsweringInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues?.question) {
tempInferrer.setQuestionText(pipelineConfigValues.question);
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION:
tempInferrer = new TextExpansionInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextExpansionInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
tempInferrer = new LangIdentInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new LangIdentInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
}
if (tempInferrer) {
if (pipelineConfigValues) {
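
`useMlTelemetryClient` is consumed above by destructuring `telemetryClient` from its return value, so it presumably wraps a React context. A minimal sketch of such a hook, assuming the conventional context pattern (everything except the `useMlTelemetryClient` and `telemetryClient` names is hypothetical):

import { createContext, useContext } from 'react';
// Path assumed relative to contexts/ml/ml_telemetry_context.
import type { ITelemetryClient } from '../../services/telemetry/types';

interface MlTelemetryContextValue {
  telemetryClient: ITelemetryClient;
}

export const MlTelemetryContext = createContext<MlTelemetryContextValue | undefined>(undefined);

export function useMlTelemetryClient(): MlTelemetryContextValue {
  const value = useContext(MlTelemetryContext);
  if (value === undefined) {
    // Fail fast when the provider is missing instead of silently dropping events.
    throw new Error('useMlTelemetryClient must be used within an MlTelemetryContext provider');
  }
  return value;
}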
@@ -89,7 +89,7 @@ describe('TrainedModelsService', () => {

mockTelemetryService = {
trackTrainedModelsDeploymentCreated: jest.fn(),
};
} as unknown as jest.Mocked<ITelemetryClient>;

mockTrainedModelsApiService = {
getTrainedModelsList: jest.fn(),
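
The `as unknown as jest.Mocked<ITelemetryClient>` double cast is the usual escape hatch when a test double implements only part of an interface. A narrower sketch that keeps type checking on the one mocked method, assuming nothing else in this suite calls the client:

// Path assumed; import the interface the mock must partially satisfy.
import type { ITelemetryClient } from '../services/telemetry/types';

// Mock only the method the suite exercises; TypeScript still checks its shape.
type TelemetryClientMock = jest.Mocked<Pick<ITelemetryClient, 'trackTrainedModelsDeploymentCreated'>>;

const mockTelemetryService: TelemetryClientMock = {
  trackTrainedModelsDeploymentCreated: jest.fn(),
};

// Widen to the full interface only at the point where the service requires it.
const telemetryClient = mockTelemetryService as unknown as ITelemetryClient;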