diff --git a/docs/changelog/91224.yaml b/docs/changelog/91224.yaml new file mode 100644 index 0000000000000..60029a19541bf --- /dev/null +++ b/docs/changelog/91224.yaml @@ -0,0 +1,5 @@ +pr: 91224 +summary: Allow NLP truncate option to be updated when span is set +area: Machine Learning +type: bug +issues: [] diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractTokenizationUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractTokenizationUpdate.java new file mode 100644 index 0000000000000..b11865768f3f6 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/AbstractTokenizationUpdate.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public abstract class AbstractTokenizationUpdate implements TokenizationUpdate { + + private final Tokenization.Truncate truncate; + private final Integer span; + + protected static void declareCommonParserFields(ConstructingObjectParser parser) { + parser.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE); + parser.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN); + } + + public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) { + this.truncate = truncate; + this.span = span; + } + + public AbstractTokenizationUpdate(StreamInput in) throws IOException { + this.truncate = in.readOptionalEnum(Tokenization.Truncate.class); + if (in.getVersion().onOrAfter(Version.V_8_2_0)) { + this.span = in.readOptionalInt(); + } else { + this.span = null; + } + } + + @Override + public boolean isNoop() { + return truncate == null && span == null; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (truncate != null) { + builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString()); + } + if (span != null) { + builder.field(Tokenization.SPAN.getPreferredName(), span); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(truncate); + if (out.getVersion().onOrAfter(Version.V_8_2_0)) { + out.writeOptionalInt(span); + } + } + + public Integer 
getSpan() { + return span; + } + + public Tokenization.Truncate getTruncate() { + return truncate; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o instanceof AbstractTokenizationUpdate == false) { + return false; + } + AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o; + return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span); + } + + @Override + public int hashCode() { + return Objects.hash(truncate, span); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java index 95db0363eefc1..3bda6f0070f03 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdate.java @@ -7,21 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.Objects; import java.util.Optional; -public class BertTokenizationUpdate implements TokenizationUpdate { +public class BertTokenizationUpdate extends AbstractTokenizationUpdate { public static final ParseField NAME = BertTokenization.NAME; @@ -31,29 +27,19 @@ public class BertTokenizationUpdate implements TokenizationUpdate { ); static { - 
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN); + declareCommonParserFields(PARSER); } public static BertTokenizationUpdate fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private final Tokenization.Truncate truncate; - private final Integer span; - public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) { - this.truncate = truncate; - this.span = span; + super(truncate, span); } public BertTokenizationUpdate(StreamInput in) throws IOException { - this.truncate = in.readOptionalEnum(Tokenization.Truncate.class); - if (in.getVersion().onOrAfter(Version.V_8_2_0)) { - this.span = in.readOptionalInt(); - } else { - this.span = null; - } + super(in); } @Override @@ -66,65 +52,41 @@ public Tokenization apply(Tokenization originalConfig) { ); } + Tokenization.validateSpanAndTruncate(getTruncate(), getSpan()); + if (isNoop()) { return originalConfig; } + if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) { + // When truncate value is incompatible with span wipe out + // the existing span setting to avoid an invalid combination of settings. 
+ // This avoids the user having to set span to the special unset value + return new BertTokenization( + originalConfig.doLowerCase(), + originalConfig.withSpecialTokens(), + originalConfig.maxSequenceLength(), + getTruncate(), + null + ); + } + return new BertTokenization( originalConfig.doLowerCase(), originalConfig.withSpecialTokens(), originalConfig.maxSequenceLength(), - Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()), - Optional.ofNullable(this.span).orElse(originalConfig.getSpan()) + Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()), + Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan()) ); } - @Override - public boolean isNoop() { - return truncate == null && span == null; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (truncate != null) { - builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString()); - } - if (span != null) { - builder.field(Tokenization.SPAN.getPreferredName(), span); - } - builder.endObject(); - return builder; - } - @Override public String getWriteableName() { return BertTokenization.NAME.getPreferredName(); } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalEnum(truncate); - if (out.getVersion().onOrAfter(Version.V_8_2_0)) { - out.writeOptionalInt(span); - } - } - @Override public String getName() { return BertTokenization.NAME.getPreferredName(); } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - BertTokenizationUpdate that = (BertTokenizationUpdate) o; - return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span); - } - - @Override - public int hashCode() { - return Objects.hash(truncate, span); - } } diff --git 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java index a94c3bba1985e..0708658db320e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdate.java @@ -7,21 +7,17 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; -import org.elasticsearch.Version; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.Objects; import java.util.Optional; -public class MPNetTokenizationUpdate implements TokenizationUpdate { +public class MPNetTokenizationUpdate extends AbstractTokenizationUpdate { public static final ParseField NAME = MPNetTokenization.NAME; @@ -31,29 +27,19 @@ public class MPNetTokenizationUpdate implements TokenizationUpdate { ); static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN); + declareCommonParserFields(PARSER); } public static MPNetTokenizationUpdate fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private final Tokenization.Truncate truncate; - private final Integer span; - public MPNetTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) { - this.truncate = 
truncate; - this.span = span; + super(truncate, span); } public MPNetTokenizationUpdate(StreamInput in) throws IOException { - this.truncate = in.readOptionalEnum(Tokenization.Truncate.class); - if (in.getVersion().onOrAfter(Version.V_8_2_0)) { - this.span = in.readOptionalInt(); - } else { - this.span = null; - } + super(in); } @Override @@ -70,61 +56,35 @@ public Tokenization apply(Tokenization originalConfig) { return originalConfig; } + if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) { + // When truncate value is incompatible with span wipe out + // the existing span setting to avoid an invalid combination of settings. + // This avoids the user having to set span to the special unset value + return new MPNetTokenization( + originalConfig.doLowerCase(), + originalConfig.withSpecialTokens(), + originalConfig.maxSequenceLength(), + getTruncate(), + null + ); + } + return new MPNetTokenization( originalConfig.doLowerCase(), originalConfig.withSpecialTokens(), originalConfig.maxSequenceLength(), - Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()), - Optional.ofNullable(this.span).orElse(originalConfig.getSpan()) + Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()), + Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan()) ); } - @Override - public boolean isNoop() { - return truncate == null && span == null; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - if (truncate != null) { - builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString()); - } - if (span != null) { - builder.field(Tokenization.SPAN.getPreferredName(), span); - } - builder.endObject(); - return builder; - } - @Override public String getWriteableName() { return MPNetTokenization.NAME.getPreferredName(); } - @Override - public void writeTo(StreamOutput out) throws IOException { - 
out.writeOptionalEnum(truncate); - if (out.getVersion().onOrAfter(Version.V_8_2_0)) { - out.writeOptionalInt(span); - } - } - @Override public String getName() { return MPNetTokenization.NAME.getPreferredName(); } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - MPNetTokenizationUpdate that = (MPNetTokenizationUpdate) o; - return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span); - } - - @Override - public int hashCode() { - return Objects.hash(truncate, span); - } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdate.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdate.java index cef929da8e4da..3763a2350dadc 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdate.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdate.java @@ -8,20 +8,16 @@ package org.elasticsearch.xpack.core.ml.inference.trainedmodel; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ParseField; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; -import java.util.Objects; import java.util.Optional; -public class RobertaTokenizationUpdate implements TokenizationUpdate { +public class RobertaTokenizationUpdate extends AbstractTokenizationUpdate { public static final ParseField NAME = new ParseField(RobertaTokenization.NAME); 
public static ConstructingObjectParser PARSER = new ConstructingObjectParser<>( @@ -30,25 +26,19 @@ public class RobertaTokenizationUpdate implements TokenizationUpdate { ); static { - PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), Tokenization.TRUNCATE); - PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), Tokenization.SPAN); + declareCommonParserFields(PARSER); } public static RobertaTokenizationUpdate fromXContent(XContentParser parser) { return PARSER.apply(parser, null); } - private final Tokenization.Truncate truncate; - private final Integer span; - public RobertaTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) { - this.truncate = truncate; - this.span = span; + super(truncate, span); } public RobertaTokenizationUpdate(StreamInput in) throws IOException { - this.truncate = in.readOptionalEnum(Tokenization.Truncate.class); - this.span = in.readOptionalInt(); + super(in); } @Override @@ -58,12 +48,27 @@ public Tokenization apply(Tokenization originalConfig) { return robertaTokenization; } + Tokenization.validateSpanAndTruncate(getTruncate(), getSpan()); + + if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) { + // When truncate value is incompatible with span wipe out + // the existing span setting to avoid an invalid combination of settings. 
+ // This avoids the user having to set span to the special unset value + return new RobertaTokenization( + robertaTokenization.withSpecialTokens(), + robertaTokenization.isAddPrefixSpace(), + robertaTokenization.maxSequenceLength(), + getTruncate(), + null + ); + } + return new RobertaTokenization( robertaTokenization.withSpecialTokens(), robertaTokenization.isAddPrefixSpace(), robertaTokenization.maxSequenceLength(), - Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()), - Optional.ofNullable(this.span).orElse(originalConfig.getSpan()) + Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()), + Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan()) ); } throw ExceptionsHelper.badRequestException( @@ -73,50 +78,13 @@ public Tokenization apply(Tokenization originalConfig) { ); } - @Override - public boolean isNoop() { - return truncate == null && span == null; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { - builder.startObject(); - if (truncate != null) { - builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString()); - } - if (span != null) { - builder.field(Tokenization.SPAN.getPreferredName(), span); - } - builder.endObject(); - return builder; - } - @Override public String getWriteableName() { return NAME.getPreferredName(); } - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalEnum(truncate); - out.writeOptionalInt(span); - } - @Override public String getName() { return NAME.getPreferredName(); } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - RobertaTokenizationUpdate that = (RobertaTokenizationUpdate) o; - return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span); - } - - @Override - public int hashCode() { - return Objects.hash(truncate, span); - } } 
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java index 174e0c0e0f92a..b4b1eb9e84725 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/Tokenization.java @@ -27,7 +27,16 @@ public abstract class Tokenization implements NamedXContentObject, NamedWriteabl public enum Truncate { FIRST, SECOND, - NONE; + NONE { + @Override + public boolean isInCompatibleWithSpan() { + return false; + } + }; + + public boolean isInCompatibleWithSpan() { + return true; + } public static Truncate fromString(String value) { return valueOf(value.toUpperCase(Locale.ROOT)); @@ -50,7 +59,7 @@ public String toString() { private static final boolean DEFAULT_DO_LOWER_CASE = false; private static final boolean DEFAULT_WITH_SPECIAL_TOKENS = true; private static final Truncate DEFAULT_TRUNCATION = Truncate.FIRST; - private static final int DEFAULT_SPAN = -1; + private static final int UNSET_SPAN_VALUE = -1; static void declareCommonFields(ConstructingObjectParser parser) { parser.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), DO_LOWER_CASE); @@ -61,7 +70,7 @@ static void declareCommonFields(ConstructingObjectParse } public static BertTokenization createDefault() { - return new BertTokenization(null, null, null, Tokenization.DEFAULT_TRUNCATION, DEFAULT_SPAN); + return new BertTokenization(null, null, null, Tokenization.DEFAULT_TRUNCATION, UNSET_SPAN_VALUE); } protected final boolean doLowerCase; @@ -84,10 +93,14 @@ public static BertTokenization createDefault() { this.withSpecialTokens = Optional.ofNullable(withSpecialTokens).orElse(DEFAULT_WITH_SPECIAL_TOKENS); this.maxSequenceLength = 
Optional.ofNullable(maxSequenceLength).orElse(DEFAULT_MAX_SEQUENCE_LENGTH); this.truncate = Optional.ofNullable(truncate).orElse(DEFAULT_TRUNCATION); - this.span = Optional.ofNullable(span).orElse(DEFAULT_SPAN); - if (this.span < 0 && this.span != -1) { + this.span = Optional.ofNullable(span).orElse(UNSET_SPAN_VALUE); + if (this.span < 0 && this.span != UNSET_SPAN_VALUE) { throw new IllegalArgumentException( - "[" + SPAN.getPreferredName() + "] must be non-negative to indicate span length or -1 to indicate no windowing should occur" + "[" + + SPAN.getPreferredName() + + "] must be non-negative to indicate span length or [" + + UNSET_SPAN_VALUE + + "] to indicate no windowing should occur" ); } if (this.span > this.maxSequenceLength) { @@ -103,17 +116,7 @@ public static BertTokenization createDefault() { + "]" ); } - if (this.span != -1 && truncate != Truncate.NONE) { - throw new IllegalArgumentException( - "[" - + SPAN.getPreferredName() - + "] must not be provided when [" - + TRUNCATE.getPreferredName() - + "] is not [" - + Truncate.NONE - + "]" - ); - } + validateSpanAndTruncate(truncate, span); } public Tokenization(StreamInput in) throws IOException { @@ -124,7 +127,7 @@ public Tokenization(StreamInput in) throws IOException { if (in.getVersion().onOrAfter(Version.V_8_2_0)) { this.span = in.readInt(); } else { - this.span = -1; + this.span = UNSET_SPAN_VALUE; } } @@ -154,6 +157,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + public static void validateSpanAndTruncate(@Nullable Truncate truncate, @Nullable Integer span) { + if ((span != null && span != UNSET_SPAN_VALUE) && (truncate != null && truncate.isInCompatibleWithSpan())) { + throw new IllegalArgumentException( + "[" + SPAN.getPreferredName() + "] must not be provided when [" + TRUNCATE.getPreferredName() + "] is [" + truncate + "]" + ); + } + } + @Override public boolean equals(Object o) { if (this == o) return true; diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdateTests.java new file mode 100644 index 0000000000000..4c34a978204ba --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/BertTokenizationUpdateTests.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import static org.hamcrest.Matchers.sameInstance; + +public class BertTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase { + + public static BertTokenizationUpdate randomInstance() { + Integer span = randomBoolean() ? null : randomIntBetween(8, 128); + Tokenization.Truncate truncate = randomBoolean() ? 
null : randomFrom(Tokenization.Truncate.values()); + + if (truncate != Tokenization.Truncate.NONE) { + span = null; + } + return new BertTokenizationUpdate(truncate, span); + } + + public void testApply() { + expectThrows( + IllegalArgumentException.class, + () -> new BertTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(BertTokenizationTests.createRandom()) + ); + + var updatedSpan = new BertTokenizationUpdate(null, 100).apply( + new BertTokenization(false, false, 512, Tokenization.Truncate.NONE, 50) + ); + assertEquals(new BertTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan); + + var updatedTruncate = new BertTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply( + new BertTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new BertTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate); + + var updatedNone = new BertTokenizationUpdate(Tokenization.Truncate.NONE, null).apply( + new BertTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new BertTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone); + + var unmodified = new BertTokenization(true, true, 512, Tokenization.Truncate.NONE, null); + assertThat(new BertTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified)); + } + + public void testNoop() { + assertTrue(new BertTokenizationUpdate(null, null).isNoop()); + assertFalse(new BertTokenizationUpdate(Tokenization.Truncate.SECOND, null).isNoop()); + assertFalse(new BertTokenizationUpdate(null, 10).isNoop()); + assertFalse(new BertTokenizationUpdate(Tokenization.Truncate.NONE, 10).isNoop()); + } + + @Override + protected Writeable.Reader instanceReader() { + return BertTokenizationUpdate::new; + } + + @Override + protected BertTokenizationUpdate createTestInstance() { + return randomInstance(); + } + + @Override + protected BertTokenizationUpdate 
mutateInstanceForVersion(BertTokenizationUpdate instance, Version version) { + if (version.before(Version.V_8_2_0)) { + return new BertTokenizationUpdate(instance.getTruncate(), null); + } + + return instance; + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdateTests.java new file mode 100644 index 0000000000000..891eba7116851 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/MPNetTokenizationUpdateTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import static org.hamcrest.Matchers.sameInstance; + +public class MPNetTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase { + + public static MPNetTokenizationUpdate randomInstance() { + Integer span = randomBoolean() ? null : randomIntBetween(8, 128); + Tokenization.Truncate truncate = randomBoolean() ? 
null : randomFrom(Tokenization.Truncate.values()); + + if (truncate != Tokenization.Truncate.NONE) { + span = null; + } + return new MPNetTokenizationUpdate(truncate, span); + } + + public void testApply() { + expectThrows( + IllegalArgumentException.class, + () -> new MPNetTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(MPNetTokenizationTests.createRandom()) + ); + + var updatedSpan = new MPNetTokenizationUpdate(null, 100).apply( + new MPNetTokenization(false, false, 512, Tokenization.Truncate.NONE, 50) + ); + assertEquals(new MPNetTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan); + + var updatedTruncate = new MPNetTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply( + new MPNetTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new MPNetTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate); + + var updatedNone = new MPNetTokenizationUpdate(Tokenization.Truncate.NONE, null).apply( + new MPNetTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new MPNetTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone); + + var unmodified = new MPNetTokenization(true, true, 512, Tokenization.Truncate.NONE, null); + assertThat(new MPNetTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified)); + } + + @Override + protected Writeable.Reader instanceReader() { + return MPNetTokenizationUpdate::new; + } + + @Override + protected MPNetTokenizationUpdate createTestInstance() { + return randomInstance(); + } + + @Override + protected MPNetTokenizationUpdate mutateInstanceForVersion(MPNetTokenizationUpdate instance, Version version) { + if (version.before(Version.V_8_2_0)) { + return new MPNetTokenizationUpdate(instance.getTruncate(), null); + } + + return instance; + } +} diff --git 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdateTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdateTests.java new file mode 100644 index 0000000000000..5addb4eec5cee --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RobertaTokenizationUpdateTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.trainedmodel; + +import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import static org.hamcrest.Matchers.sameInstance; + +public class RobertaTokenizationUpdateTests extends AbstractBWCWireSerializationTestCase { + + public static RobertaTokenizationUpdate randomInstance() { + Integer span = randomBoolean() ? null : randomIntBetween(8, 128); + Tokenization.Truncate truncate = randomBoolean() ? 
null : randomFrom(Tokenization.Truncate.values()); + + if (truncate != Tokenization.Truncate.NONE) { + span = null; + } + return new RobertaTokenizationUpdate(truncate, span); + } + + public void testApply() { + expectThrows( + IllegalArgumentException.class, + () -> new RobertaTokenizationUpdate(Tokenization.Truncate.SECOND, 100).apply(RobertaTokenizationTests.createRandom()) + ); + + var updatedSpan = new RobertaTokenizationUpdate(null, 100).apply( + new RobertaTokenization(false, false, 512, Tokenization.Truncate.NONE, 50) + ); + assertEquals(new RobertaTokenization(false, false, 512, Tokenization.Truncate.NONE, 100), updatedSpan); + + var updatedTruncate = new RobertaTokenizationUpdate(Tokenization.Truncate.FIRST, null).apply( + new RobertaTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new RobertaTokenization(true, true, 512, Tokenization.Truncate.FIRST, null), updatedTruncate); + + var updatedNone = new RobertaTokenizationUpdate(Tokenization.Truncate.NONE, null).apply( + new RobertaTokenization(true, true, 512, Tokenization.Truncate.SECOND, null) + ); + assertEquals(new RobertaTokenization(true, true, 512, Tokenization.Truncate.NONE, null), updatedNone); + + var unmodified = new RobertaTokenization(true, true, 512, Tokenization.Truncate.NONE, null); + assertThat(new RobertaTokenizationUpdate(null, null).apply(unmodified), sameInstance(unmodified)); + } + + @Override + protected Writeable.Reader instanceReader() { + return RobertaTokenizationUpdate::new; + } + + @Override + protected RobertaTokenizationUpdate createTestInstance() { + return randomInstance(); + } + + @Override + protected RobertaTokenizationUpdate mutateInstanceForVersion(RobertaTokenizationUpdate instance, Version version) { + if (version.before(Version.V_8_2_0)) { + return new RobertaTokenizationUpdate(instance.getTruncate(), null); + } + + return instance; + } +}