diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java index 90c2aa7412935..9659ffbe817f1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/InternalInferModelActionRequestTests.java @@ -14,9 +14,22 @@ import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NlpConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdateTests; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ResultsFieldUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdateTests; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdateTests; import java.util.ArrayList; import java.util.List; @@ -44,6 +57,12 @@ private static InferenceConfigUpdate randomInferenceConfigUpdate() { RegressionConfigUpdateTests.randomRegressionConfigUpdate(), ClassificationConfigUpdateTests.randomClassificationConfigUpdate(), ResultsFieldUpdateTests.randomUpdate(), + TextClassificationConfigUpdateTests.randomUpdate(), + TextEmbeddingConfigUpdateTests.randomUpdate(), + NerConfigUpdateTests.randomUpdate(), + FillMaskConfigUpdateTests.randomUpdate(), + ZeroShotClassificationConfigUpdateTests.randomUpdate(), + PassThroughConfigUpdateTests.randomUpdate(), EmptyConfigUpdateTests.testInstance() ); } @@ -68,6 +87,27 @@ protected NamedWriteableRegistry getNamedWriteableRegistry() { @Override protected Request mutateInstanceForVersion(Request instance, Version version) { - return instance; + InferenceConfigUpdate adjustedUpdate; + InferenceConfigUpdate currentUpdate = instance.getUpdate(); + if (currentUpdate instanceof NlpConfigUpdate nlpConfigUpdate) { + if (nlpConfigUpdate instanceof TextClassificationConfigUpdate update) { + adjustedUpdate = TextClassificationConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof TextEmbeddingConfigUpdate update) { + adjustedUpdate = TextEmbeddingConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof NerConfigUpdate update) { + adjustedUpdate = NerConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof FillMaskConfigUpdate update) { + adjustedUpdate = FillMaskConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof ZeroShotClassificationConfigUpdate update) { + adjustedUpdate = ZeroShotClassificationConfigUpdateTests.mutateForVersion(update, version); + } else if (nlpConfigUpdate instanceof PassThroughConfigUpdate update) { + adjustedUpdate = PassThroughConfigUpdateTests.mutateForVersion(update, version); + } else { + throw new IllegalArgumentException("Unknown update [" + currentUpdate.getName() + "]"); + } + } else { + adjustedUpdate = currentUpdate; + } + return new Request(instance.getModelId(), instance.getObjectsToInfer(), adjustedUpdate, instance.isPreviouslyLicensed()); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java index 1486946ee79e4..ed174655e6c99 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/FillMaskConfigUpdateTests.java @@ -25,6 +25,27 @@ public class FillMaskConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static FillMaskConfigUpdate randomUpdate() { + FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setNumTopClasses(randomIntBetween(1, 4)); + } + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static FillMaskConfigUpdate mutateForVersion(FillMaskConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null); + } + return instance; + } + @Override Tuple, FillMaskConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { int topClasses = randomIntBetween(1, 10); @@ -103,25 +124,12 @@ protected Writeable.Reader instanceReader() { @Override protected FillMaskConfigUpdate createTestInstance() { - FillMaskConfigUpdate.Builder builder = new FillMaskConfigUpdate.Builder(); - if (randomBoolean()) { - builder.setNumTopClasses(randomIntBetween(1, 4)); - } - if (randomBoolean()) { - builder.setResultsField(randomAlphaOfLength(8)); - } - if (randomBoolean()) { - builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); - } - return builder.build(); + return randomUpdate(); } @Override protected FillMaskConfigUpdate mutateInstanceForVersion(FillMaskConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new FillMaskConfigUpdate(instance.getNumTopClasses(), instance.getResultsField(), null); - } - return instance; + return mutateForVersion(instance, version); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java index 03fddb7409af6..d1089330287fc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NerConfigUpdateTests.java @@ -26,6 +26,24 @@ public class NerConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static NerConfigUpdate randomUpdate() { + NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static NerConfigUpdate mutateForVersion(NerConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new NerConfigUpdate(instance.getResultsField(), null); + } + return instance; + } + @Override Tuple, NerConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { NerConfigUpdate expected = new NerConfigUpdate("ml-results", expectedTokenization); @@ -86,22 +104,12 @@ protected Writeable.Reader instanceReader() { @Override protected NerConfigUpdate createTestInstance() { - NerConfigUpdate.Builder builder = new NerConfigUpdate.Builder(); - if (randomBoolean()) { - builder.setResultsField(randomAlphaOfLength(8)); - } - if (randomBoolean()) { - builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); - } - return builder.build(); + return randomUpdate(); } @Override protected NerConfigUpdate mutateInstanceForVersion(NerConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new NerConfigUpdate(instance.getResultsField(), null); - } - return instance; + return mutateForVersion(instance, version); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java index 125eba8b31df5..002b9da8a5223 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/PassThroughConfigUpdateTests.java @@ -26,6 +26,24 @@ public class PassThroughConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static PassThroughConfigUpdate randomUpdate() { + PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static PassThroughConfigUpdate mutateForVersion(PassThroughConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new PassThroughConfigUpdate(instance.getResultsField(), null); + } + return instance; + } + @Override Tuple, PassThroughConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { PassThroughConfigUpdate expected = new PassThroughConfigUpdate("ml-results", expectedTokenization); @@ -76,22 +94,12 @@ protected Writeable.Reader instanceReader() { @Override protected PassThroughConfigUpdate createTestInstance() { - PassThroughConfigUpdate.Builder builder = new PassThroughConfigUpdate.Builder(); - if (randomBoolean()) { - builder.setResultsField(randomAlphaOfLength(8)); - } - if (randomBoolean()) { - builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); - } - return builder.build(); + return randomUpdate(); } @Override protected PassThroughConfigUpdate mutateInstanceForVersion(PassThroughConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new PassThroughConfigUpdate(instance.getResultsField(), null); - } - return instance; + return mutateForVersion(instance, version); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java index 7515d74123440..a8c2c871e488b 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextClassificationConfigUpdateTests.java @@ -28,6 +28,35 @@ public class TextClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static TextClassificationConfigUpdate randomUpdate() { + TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setNumTopClasses(randomIntBetween(1, 4)); + } + if (randomBoolean()) { + builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4))); + } + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static TextClassificationConfigUpdate mutateForVersion(TextClassificationConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new TextClassificationConfigUpdate( + instance.getClassificationLabels(), + instance.getNumTopClasses(), + instance.getResultsField(), + null + ); + } + return instance; + } + @Override Tuple, TextClassificationConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { int numClasses = randomIntBetween(1, 10); @@ -159,33 +188,12 @@ protected Writeable.Reader instanceReader() { @Override protected TextClassificationConfigUpdate createTestInstance() { - TextClassificationConfigUpdate.Builder builder = new TextClassificationConfigUpdate.Builder(); - if (randomBoolean()) { - builder.setNumTopClasses(randomIntBetween(1, 4)); - } - if (randomBoolean()) { - builder.setClassificationLabels(randomList(1, 3, () -> randomAlphaOfLength(4))); - } - if (randomBoolean()) { - builder.setResultsField(randomAlphaOfLength(8)); - } - if (randomBoolean()) { - builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); - } - return builder.build(); + return randomUpdate(); } @Override protected TextClassificationConfigUpdate mutateInstanceForVersion(TextClassificationConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new TextClassificationConfigUpdate( - instance.getClassificationLabels(), - instance.getNumTopClasses(), - instance.getResultsField(), - null - ); - } - return instance; + return mutateForVersion(instance, version); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java index 745802e890c4b..987722e291afe 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/TextEmbeddingConfigUpdateTests.java @@ -26,6 +26,24 @@ public class TextEmbeddingConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static TextEmbeddingConfigUpdate randomUpdate() { + TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder(); + if (randomBoolean()) { + builder.setResultsField(randomAlphaOfLength(8)); + } + if (randomBoolean()) { + builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); + } + return builder.build(); + } + + public static TextEmbeddingConfigUpdate mutateForVersion(TextEmbeddingConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new TextEmbeddingConfigUpdate(instance.getResultsField(), null); + } + return instance; + } + @Override Tuple, TextEmbeddingConfigUpdate> fromMapTestInstances(TokenizationUpdate expectedTokenization) { TextEmbeddingConfigUpdate expected = new TextEmbeddingConfigUpdate("ml-results", expectedTokenization); @@ -76,22 +94,12 @@ protected Writeable.Reader instanceReader() { @Override protected TextEmbeddingConfigUpdate createTestInstance() { - TextEmbeddingConfigUpdate.Builder builder = new TextEmbeddingConfigUpdate.Builder(); - if (randomBoolean()) { - builder.setResultsField(randomAlphaOfLength(8)); - } - if (randomBoolean()) { - builder.setTokenizationUpdate(new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null)); - } - return builder.build(); + return randomUpdate(); } @Override protected TextEmbeddingConfigUpdate mutateInstanceForVersion(TextEmbeddingConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new TextEmbeddingConfigUpdate(instance.getResultsField(), null); - } - return instance; + return mutateForVersion(instance, version); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java index 775e5824ea3c0..7aa80885ed7f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ZeroShotClassificationConfigUpdateTests.java @@ -27,6 +27,22 @@ public class ZeroShotClassificationConfigUpdateTests extends AbstractNlpConfigUpdateTestCase { + public static ZeroShotClassificationConfigUpdate randomUpdate() { + return new ZeroShotClassificationConfigUpdate( + randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)), + randomBoolean() ? null : randomBoolean(), + randomBoolean() ? null : randomAlphaOfLength(5), + randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null) + ); + } + + public static ZeroShotClassificationConfigUpdate mutateForVersion(ZeroShotClassificationConfigUpdate instance, Version version) { + if (version.before(Version.V_8_1_0)) { + return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null); + } + return instance; + } + @Override protected boolean supportsUnknownFields() { return false; @@ -49,10 +65,7 @@ protected ZeroShotClassificationConfigUpdate createTestInstance() { @Override protected ZeroShotClassificationConfigUpdate mutateInstanceForVersion(ZeroShotClassificationConfigUpdate instance, Version version) { - if (version.before(Version.V_8_1_0)) { - return new ZeroShotClassificationConfigUpdate(instance.getLabels(), instance.getMultiLabel(), instance.getResultsField(), null); - } - return instance; + return mutateForVersion(instance, version); } @Override @@ -197,12 +210,7 @@ public void testIsNoop() { } public static ZeroShotClassificationConfigUpdate createRandom() { - return new ZeroShotClassificationConfigUpdate( - randomBoolean() ? null : randomList(1, 5, () -> randomAlphaOfLength(10)), - randomBoolean() ? null : randomBoolean(), - randomBoolean() ? null : randomAlphaOfLength(5), - randomBoolean() ? null : new BertTokenizationUpdate(randomFrom(Tokenization.Truncate.values()), null) - ); + return randomUpdate(); } @Override