Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/91224.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 91224
summary: Allow NLP truncate option to be updated when span is set
area: Machine Learning
type: bug
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

/**
 * Base class for {@link TokenizationUpdate} implementations whose only
 * updatable settings are the truncate strategy and the span.
 *
 * <p>Both settings are optional; a {@code null} value means "leave the
 * existing setting unchanged". Wire serialization, XContent rendering,
 * equality and hashing of these two settings are implemented here so that
 * concrete subclasses do not have to repeat them.
 */
public abstract class AbstractTokenizationUpdate implements TokenizationUpdate {

    private final Tokenization.Truncate truncate;
    private final Integer span;

    /**
     * @param truncate the new truncate setting, or {@code null} if it is not being updated
     * @param span the new span setting, or {@code null} if it is not being updated
     */
    public AbstractTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
        this.truncate = truncate;
        this.span = span;
    }

    /**
     * Reads the update from the wire.
     *
     * @param in the stream to read from
     * @throws IOException if reading from the stream fails
     */
    public AbstractTokenizationUpdate(StreamInput in) throws IOException {
        this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
        // span is only present on the wire for streams at version 8.2.0 or later
        if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
            this.span = in.readOptionalInt();
        } else {
            this.span = null;
        }
    }

    /**
     * @return {@code true} when neither truncate nor span is set
     */
    @Override
    public boolean isNoop() {
        return truncate == null && span == null;
    }

    /**
     * Writes an object containing only the settings that are set.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (truncate != null) {
            builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
        }
        if (span != null) {
            builder.field(Tokenization.SPAN.getPreferredName(), span);
        }
        builder.endObject();
        return builder;
    }

    /**
     * Writes the update to the wire, mirroring {@link #AbstractTokenizationUpdate(StreamInput)}.
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalEnum(truncate);
        // span is only written to streams at version 8.2.0 or later
        if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
            out.writeOptionalInt(span);
        }
    }

    /**
     * @return the span update, or {@code null} if span is not being updated
     */
    public Integer getSpan() {
        return span;
    }

    /**
     * @return the truncate update, or {@code null} if truncate is not being updated
     */
    public Tokenization.Truncate getTruncate() {
        return truncate;
    }

    /**
     * Two updates are equal only when they are of the same concrete class and
     * carry the same truncate and span values.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        // Compare the exact runtime class rather than using instanceof so that
        // updates of different concrete types never compare equal, even when
        // their truncate and span values match.
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        AbstractTokenizationUpdate that = (AbstractTokenizationUpdate) o;
        return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
    }

    @Override
    public int hashCode() {
        return Objects.hash(truncate, span);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,17 @@

package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.Version;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class BertTokenizationUpdate implements TokenizationUpdate {
public class BertTokenizationUpdate extends AbstractTokenizationUpdate {

public static final ParseField NAME = BertTokenization.NAME;

Expand All @@ -39,21 +35,12 @@ public static BertTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public BertTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public BertTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
if (in.getVersion().onOrAfter(Version.V_8_2_0)) {
this.span = in.readOptionalInt();
} else {
this.span = null;
}
super(in);
}

@Override
Expand All @@ -66,65 +53,41 @@ public Tokenization apply(Tokenization originalConfig) {
);
}

Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

if (isNoop()) {
return originalConfig;
}

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When the truncate value is incompatible with span, wipe out
// the existing span setting to avoid an invalid combination of settings.
// This saves the user from having to set span to the special unset value.
return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
getTruncate(),
null
);
}

return new BertTokenization(
originalConfig.doLowerCase(),
originalConfig.withSpecialTokens(),
originalConfig.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(getSpan()).orElse(originalConfig.getSpan())
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
if (out.getVersion().onOrAfter(Version.V_8_2_0)) {
out.writeOptionalInt(span);
}
}

@Override
public String getName() {
return BertTokenization.NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
BertTokenizationUpdate that = (BertTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public static RobertaTokenization fromXContent(XContentParser parser, boolean le

private final boolean addPrefixSpace;

private RobertaTokenization(
public RobertaTokenization(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is there a constructor that doesn't expect doLowerCase for RobertaTokenization? The only public constructor before this change forces lower case to false. This seems like something worth clarifying while we're refactoring this area.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reverted this change as it was only a convenience for a test. The value of doLowerCase should always be false.

@Nullable Boolean doLowerCase,
@Nullable Boolean withSpecialTokens,
@Nullable Integer maxSequenceLength,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@
package org.elasticsearch.xpack.core.ml.inference.trainedmodel;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

import java.io.IOException;
import java.util.Objects;
import java.util.Optional;

public class RobertaTokenizationUpdate implements TokenizationUpdate {
public class RobertaTokenizationUpdate extends AbstractTokenizationUpdate {
public static final ParseField NAME = new ParseField(RobertaTokenization.NAME);

public static ConstructingObjectParser<RobertaTokenizationUpdate, Void> PARSER = new ConstructingObjectParser<>(
Expand All @@ -38,17 +34,12 @@ public static RobertaTokenizationUpdate fromXContent(XContentParser parser) {
return PARSER.apply(parser, null);
}

private final Tokenization.Truncate truncate;
private final Integer span;

public RobertaTokenizationUpdate(@Nullable Tokenization.Truncate truncate, @Nullable Integer span) {
this.truncate = truncate;
this.span = span;
super(truncate, span);
}

public RobertaTokenizationUpdate(StreamInput in) throws IOException {
this.truncate = in.readOptionalEnum(Tokenization.Truncate.class);
this.span = in.readOptionalInt();
super(in);
}

@Override
Expand All @@ -58,12 +49,29 @@ public Tokenization apply(Tokenization originalConfig) {
return robertaTokenization;
}

Tokenization.validateSpanAndTruncate(getTruncate(), getSpan());

if (getTruncate() != null && getTruncate().isInCompatibleWithSpan() == false) {
// When the truncate value is incompatible with span, wipe out
// the existing span setting to avoid an invalid combination of settings.
// This saves the user from having to set span to the special unset value.
return new RobertaTokenization(
robertaTokenization.doLowerCase(),
robertaTokenization.withSpecialTokens(),
robertaTokenization.maxSequenceLength(),
getTruncate(),
null,
robertaTokenization.isAddPrefixSpace()
);
}

return new RobertaTokenization(
robertaTokenization.doLowerCase(),
robertaTokenization.withSpecialTokens(),
robertaTokenization.isAddPrefixSpace(),
robertaTokenization.maxSequenceLength(),
Optional.ofNullable(this.truncate).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.span).orElse(originalConfig.getSpan())
Optional.ofNullable(this.getTruncate()).orElse(originalConfig.getTruncate()),
Optional.ofNullable(this.getSpan()).orElse(originalConfig.getSpan()),
robertaTokenization.isAddPrefixSpace()
);
}
throw ExceptionsHelper.badRequestException(
Expand All @@ -73,50 +81,13 @@ public Tokenization apply(Tokenization originalConfig) {
);
}

@Override
public boolean isNoop() {
return truncate == null && span == null;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
if (truncate != null) {
builder.field(Tokenization.TRUNCATE.getPreferredName(), truncate.toString());
}
if (span != null) {
builder.field(Tokenization.SPAN.getPreferredName(), span);
}
builder.endObject();
return builder;
}

@Override
public String getWriteableName() {
return NAME.getPreferredName();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalEnum(truncate);
out.writeOptionalInt(span);
}

@Override
public String getName() {
return NAME.getPreferredName();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
RobertaTokenizationUpdate that = (RobertaTokenizationUpdate) o;
return Objects.equals(truncate, that.truncate) && Objects.equals(span, that.span);
}

@Override
public int hashCode() {
return Objects.hash(truncate, span);
}
}
Loading