Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import { ES_FIELD_TYPES } from '@kbn/field-types';
import type { MLHttpFetchError } from '@kbn/ml-error-utils';
import type { trainedModelsApiProvider } from '../../../services/ml_api_service/trained_models';
import { getInferenceInfoComponent } from './inference_info';
import type { ITelemetryClient } from '../../../services/telemetry/types';

export type InferenceType =
| SupportedPytorchTasksType
Expand Down Expand Up @@ -84,7 +85,8 @@ export abstract class InferenceBase<TInferResponse> {
protected readonly trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
protected readonly model: estypes.MlTrainedModelConfig,
protected readonly inputType: INPUT_TYPE,
protected readonly deploymentId: string
protected readonly deploymentId: string,
private readonly telemetryClient: ITelemetryClient
) {
this.modelInputField = model.input?.field_names[0] ?? DEFAULT_INPUT_FIELD;
this.inputField$.next(this.modelInputField);
Expand Down Expand Up @@ -317,9 +319,14 @@ export abstract class InferenceBase<TInferResponse> {
this.inferenceResult$.next([processedResponse]);
this.setFinished();

this.trackModelTested('success');

return [processedResponse];
} catch (error) {
this.setFinishedWithErrors(error);

this.trackModelTested('failure');

throw error;
}
}
Expand All @@ -336,9 +343,15 @@ export abstract class InferenceBase<TInferResponse> {
const processedResponse = docs.map((d) => processResponse(this.getDocFromResponse(d)));
this.inferenceResult$.next(processedResponse);
this.setFinished();

this.trackModelTested('success');

return processedResponse;
} catch (error) {
this.setFinishedWithErrors(error);

this.trackModelTested('failure');

throw error;
}
}
Expand Down Expand Up @@ -392,4 +405,13 @@ export abstract class InferenceBase<TInferResponse> {
}
return doc;
}

private trackModelTested(result: 'success' | 'failure') {
this.telemetryClient.trackTrainedModelsModelTested({
model_id: this.model.model_id,
model_type: this.model.model_type,
task_type: this.inferenceType,
result,
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import { InferenceBase, INPUT_TYPE } from '../inference_base';
import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getNerOutputComponent } from './ner_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export type FormattedNerResponse = Array<{
value: string;
Expand All @@ -37,9 +38,10 @@ export class NerInference extends InferenceBase<NerResponse> {
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import type { InferResponse, INPUT_TYPE } from '../inference_base';
import { getQuestionAnsweringInput } from './question_answering_input';
import { getQuestionAnsweringOutputComponent } from './question_answering_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface RawQuestionAnsweringResponse {
inference_results: Array<{
Expand Down Expand Up @@ -63,9 +64,10 @@ export class QuestionAnsweringInference extends InferenceBase<QuestionAnsweringR
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.questionText$.pipe(map((questionText) => questionText !== ''))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import { processResponse, processInferenceResult } from './common';
import { getGeneralInputComponent } from '../text_input';
import { getFillMaskOutputComponent } from './fill_mask_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

const DEFAULT_MASK_TOKEN = '[MASK]';

Expand All @@ -36,9 +37,10 @@ export class FillMaskInference extends InferenceBase<TextClassificationResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);
const maskToken = model.inference_config?.[this.inferenceType]?.mask_token;
if (maskToken) {
this.maskToken = maskToken;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import { getGeneralInputComponent } from '../text_input';
import { getLangIdentOutputComponent } from './lang_ident_output';
import type { TextClassificationResponse, RawTextClassificationResponse } from './common';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class LangIdentInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType: InferenceType = 'classification';
Expand All @@ -32,9 +33,10 @@ export class LangIdentInference extends InferenceBase<TextClassificationResponse
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '
import { getGeneralInputComponent } from '../text_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class TextClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION;
Expand All @@ -31,9 +32,10 @@ export class TextClassificationInference extends InferenceBase<TextClassificatio
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import type { TextClassificationResponse, RawTextClassificationResponse } from '

import { getZeroShotClassificationInput } from './zero_shot_classification_input';
import { getTextClassificationOutputComponent } from './text_classification_output';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export class ZeroShotClassificationInference extends InferenceBase<TextClassificationResponse> {
protected inferenceType = SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION;
Expand All @@ -39,9 +40,10 @@ export class ZeroShotClassificationInference extends InferenceBase<TextClassific
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.labelsText$.pipe(map((labelsText) => labelsText !== ''))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type { InferResponse } from '../inference_base';
import { getGeneralInputComponent } from '../text_input';
import { getTextEmbeddingOutputComponent } from './text_embedding_output';
import type { trainedModelsApiProvider } from '../../../../services/ml_api_service/trained_models';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface RawTextEmbeddingResponse {
inference_results: Array<{ predicted_value: number[] }>;
Expand Down Expand Up @@ -43,9 +44,10 @@ export class TextEmbeddingInference extends InferenceBase<TextEmbeddingResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import type { INPUT_TYPE } from '../inference_base';
import { InferenceBase, type InferResponse } from '../inference_base';
import { getTextExpansionOutputComponent } from './text_expansion_output';
import { getTextExpansionInput } from './text_expansion_input';
import type { ITelemetryClient } from '../../../../services/telemetry/types';

export interface TextExpansionPair {
token: string;
Expand Down Expand Up @@ -53,9 +54,10 @@ export class TextExpansionInference extends InferenceBase<TextExpansionResponse>
trainedModelsApi: ReturnType<typeof trainedModelsApiProvider>,
model: estypes.MlTrainedModelConfig,
inputType: INPUT_TYPE,
deploymentId: string
deploymentId: string,
telemetryClient: ITelemetryClient
) {
super(trainedModelsApi, model, inputType, deploymentId);
super(trainedModelsApi, model, inputType, deploymentId, telemetryClient);

this.initialize(
[this.queryText$.pipe(map((questionText) => questionText !== ''))],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import {
isMlIngestInferenceProcessor,
isMlInferencePipelineInferenceConfig,
} from '../create_pipeline_for_model/get_inference_properties_from_pipeline_config';
import { useMlTelemetryClient } from '../../contexts/ml/ml_telemetry_context';

interface Props {
model: estypes.MlTrainedModelConfig;
Expand All @@ -54,6 +55,7 @@ export const SelectedModel: FC<Props> = ({
setCurrentContext,
}) => {
const { trainedModels } = useMlApi();
const { telemetryClient } = useMlTelemetryClient();

const inferrer = useMemo<InferrerType | undefined>(() => {
const taskType = Object.keys(model.inference_config ?? {})[0];
Expand All @@ -65,22 +67,30 @@ export const SelectedModel: FC<Props> = ({
if (model.model_type === TRAINED_MODEL_TYPE.PYTORCH) {
switch (taskType) {
case SUPPORTED_PYTORCH_TASKS.NER:
tempInferrer = new NerInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new NerInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_CLASSIFICATION:
tempInferrer = new TextClassificationInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.ZERO_SHOT_CLASSIFICATION:
tempInferrer = new ZeroShotClassificationInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues) {
const { labels, multi_label: multiLabel } = pipelineConfigValues;
Expand All @@ -91,30 +101,55 @@ export const SelectedModel: FC<Props> = ({
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EMBEDDING:
tempInferrer = new TextEmbeddingInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextEmbeddingInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.FILL_MASK:
tempInferrer = new FillMaskInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new FillMaskInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
case SUPPORTED_PYTORCH_TASKS.QUESTION_ANSWERING:
tempInferrer = new QuestionAnsweringInference(
trainedModels,
model,
inputType,
deploymentId
deploymentId,
telemetryClient
);
if (pipelineConfigValues?.question) {
tempInferrer.setQuestionText(pipelineConfigValues.question);
}
break;
case SUPPORTED_PYTORCH_TASKS.TEXT_EXPANSION:
tempInferrer = new TextExpansionInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new TextExpansionInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
break;
default:
break;
}
} else if (model.model_type === TRAINED_MODEL_TYPE.LANG_IDENT) {
tempInferrer = new LangIdentInference(trainedModels, model, inputType, deploymentId);
tempInferrer = new LangIdentInference(
trainedModels,
model,
inputType,
deploymentId,
telemetryClient
);
}
if (tempInferrer) {
if (pipelineConfigValues) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ describe('TrainedModelsService', () => {

mockTelemetryService = {
trackTrainedModelsDeploymentCreated: jest.fn(),
};
} as unknown as jest.Mocked<ITelemetryClient>;

mockTrainedModelsApiService = {
getTrainedModelsList: jest.fn(),
Expand Down
Loading