3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

### Enhancements
- [Semantic Field] Support configuring the auto-generated knn_vector field through the semantic field. ([#1420](https://github.com/opensearch-project/neural-search/pull/1420))
- [Semantic Field] Support configuring the ingest batch size for the semantic field.
- [Semantic Field] Support configuring the ingest batch size for the semantic field. ([#1438](https://github.com/opensearch-project/neural-search/pull/1438))
- [Semantic Field] Allow configuring prune strategies for sparse encoding in semantic fields. ([#1434](https://github.com/opensearch-project/neural-search/pull/1434))

### Bug Fixes
- Fix for collapse bug with knn query not deduplicating results ([#1413](https://github.com/opensearch-project/neural-search/pull/1413))
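Taken together with the constants and mapper changes below, the new option is configured per semantic field in the index mapping. The sketch below shows one plausible shape; the field name, model id, the "max_ratio" strategy, and the literal "prune_type"/"prune_ratio" keys are assumptions inferred from the PruneUtils/PruneType usage in this diff, not copied from the PR documentation.

public class SparseEncodingMappingSketch {
    // Illustrative only: a semantic field backed by a sparse encoding model, with pruning enabled
    // through the new sparse_encoding_config option. Key names and values are assumptions (see above).
    public static final String MAPPING = """
        {
          "properties": {
            "passage_text": {
              "type": "semantic",
              "model_id": "<sparse-encoding-model-id>",
              "sparse_encoding_config": {
                "prune_type": "max_ratio",
                "prune_ratio": 0.1
              }
            }
          }
        }
        """;
}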
@@ -53,4 +53,10 @@ public class SemanticFieldConstants {
* knn_vector field.
*/
public static final String DENSE_EMBEDDING_CONFIG = "dense_embedding_config";

/**
* Name of the field for the sparse encoding config. The config controls how sparse encoding is performed for the semantic field.
* {@link org.opensearch.neuralsearch.mapper.dto.SparseEncodingConfig}
*/
public static final String SPARSE_ENCODING_CONFIG = "sparse_encoding_config";
}
@@ -22,6 +22,7 @@
import org.opensearch.index.mapper.WildcardFieldMapper;
import org.opensearch.neuralsearch.constants.MappingConstants;
import org.opensearch.neuralsearch.mapper.dto.SemanticParameters;
import org.opensearch.neuralsearch.mapper.dto.SparseEncodingConfig;

import java.io.IOException;
import java.util.HashMap;
@@ -39,6 +40,7 @@
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEARCH_MODEL_ID;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_INFO_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_FIELD_SEARCH_ANALYZER;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;

/**
* FieldMapper for the semantic field. It will hold a delegate field mapper to delegate the data parsing and query work
@@ -171,6 +173,20 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
}
}, (v) -> v == null ? null : v.toString());

protected final Parameter<SparseEncodingConfig> sparseEncodingConfig = new Parameter<>(
SPARSE_ENCODING_CONFIG,
false,
() -> null,
(name, ctx, value) -> new SparseEncodingConfig(name, value),
m -> ((SemanticFieldMapper) m).semanticParameters.getSparseEncodingConfig()
).setSerializer((builder, name, value) -> {
if (value == null) {
builder.nullField(name);
} else {
value.toXContent(builder, name);
}
}, (value) -> value == null ? null : value.toString());

@Setter
protected ParametrizedFieldMapper.Builder delegateBuilder;

@@ -187,7 +203,8 @@ protected List<Parameter<?>> getParameters() {
semanticInfoFieldName,
chunkingEnabled,
semanticFieldSearchAnalyzer,
denseEmbeddingConfig
denseEmbeddingConfig,
sparseEncodingConfig
);
}

@@ -217,6 +234,7 @@ public SemanticParameters getSemanticParameters() {
.chunkingEnabled(chunkingEnabled.getValue())
.semanticFieldSearchAnalyzer(semanticFieldSearchAnalyzer.getValue())
.denseEmbeddingConfig(denseEmbeddingConfig.getValue())
.sparseEncodingConfig(sparseEncodingConfig.getValue())
.build();
}
}
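The new sparse_encoding_config parameter above is non-updateable and is parsed into a SparseEncodingConfig DTO, so the DTO's equals/hashCode/toString (added later in this diff) presumably let the mapper framework compare the value on mapping merges and render it in conflict messages. A minimal sketch of that behavior, assuming "max_ratio" is a valid PruneType with 0.1 as a legal ratio and that PRUNE_TYPE_FIELD/PRUNE_RATIO_FIELD are accessible PruneUtils constants; these are inferences, not shown in this diff.

import java.util.Map;

import org.opensearch.neuralsearch.mapper.dto.SparseEncodingConfig;

import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD;

public class SparseEncodingConfigEqualitySketch {
    public static void main(String[] args) {
        final Map<String, Object> raw = Map.of(PRUNE_TYPE_FIELD, "max_ratio", PRUNE_RATIO_FIELD, 0.1f);
        // The constructor copies the map before consuming entries, so the same input can be reused.
        final SparseEncodingConfig a = new SparseEncodingConfig(SPARSE_ENCODING_CONFIG, raw);
        final SparseEncodingConfig b = new SparseEncodingConfig(SPARSE_ENCODING_CONFIG, raw);
        System.out.println(a.equals(b)); // true -> an unchanged value should not be reported as a mapping conflict
        System.out.println(a);           // e.g. {prune_type=max_ratio, prune_ratio=0.1}, via the map-style toString()
    }
}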
@@ -22,4 +22,5 @@ public class SemanticParameters {
private final Boolean chunkingEnabled;
private final String semanticFieldSearchAnalyzer;
private final Map<String, Object> denseEmbeddingConfig;
private final SparseEncodingConfig sparseEncodingConfig;
}
@@ -0,0 +1,204 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.mapper.dto;

import lombok.Getter;
import lombok.NonNull;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.mapper.MapperParsingException;
import org.opensearch.neuralsearch.util.prune.PruneType;
import org.opensearch.neuralsearch.util.prune.PruneUtils;

import java.io.IOException;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD;

@Getter
public class SparseEncodingConfig {
private PruneType pruneType;
private Float pruneRatio;

/**
* Construct SparseEncodingConfig from the JSON input defined for the sparse encoding config in the index mapping.
* @param name parameter name
* @param value parameter value
*/
public SparseEncodingConfig(@NonNull final String name, final Object value) {
if (value instanceof Map == false) {
throw new MapperParsingException(String.format(Locale.ROOT, "[%s] must be a Map", name));
}
final Map<String, Object> config = new HashMap<>((Map<String, Object>) value);
final PruneType pruneType = consumePruneType(config);
final Float pruneRatio = consumePruneRatio(config);

// Check for any unrecognized parameters
if (config.isEmpty() == false) {
throw new MapperParsingException(
String.format(Locale.ROOT, "Unsupported parameters %s in %s", String.join(",", config.keySet()), name)
);
}

// Case: pruneType is null and pruneRatio is null → nothing configured
if (pruneType == null && pruneRatio == null) {
return;
}

// Case: pruneRatio is set but pruneType is null or NONE → invalid
if (pruneRatio != null && (pruneType == null || PruneType.NONE.equals(pruneType))) {
throw new MapperParsingException(
String.format(
Locale.ROOT,
"%s should not be defined when %s is %s or null",
PRUNE_RATIO_FIELD,
PRUNE_TYPE_FIELD,
PruneType.NONE.getValue()
)
);
}

// Case: pruneType is defined and not NONE and pruneRatio is null → missing pruneRatio
if (pruneRatio == null && PruneType.NONE.equals(pruneType) == false) {
throw new MapperParsingException(
String.format(
Locale.ROOT,
"%s is required when %s is defined and not %s",
PRUNE_RATIO_FIELD,
PRUNE_TYPE_FIELD,
PruneType.NONE.getValue()
)
);
}

// Case: pruneType is NONE and pruneRatio is null
if (pruneRatio == null) {
this.pruneType = pruneType;
return;
}

// Case: pruneType is neither NONE nor null, and pruneRatio is not null
if (PruneUtils.isValidPruneRatio(pruneType, pruneRatio) == false) {
throw new MapperParsingException(
String.format(
Locale.ROOT,
"Illegal prune_ratio %f for prune_type: %s. %s",
pruneRatio,
pruneType.getValue(),
PruneUtils.getValidPruneRatioDescription(pruneType)
)
);
}

this.pruneType = pruneType;
this.pruneRatio = pruneRatio;
}

/**
* Construct SparseEncodingConfig from a semantic field config. Should only be used with a valid semantic field config.
* @param fieldConfig semantic field config
*/
public SparseEncodingConfig(@NonNull final Map<String, Object> fieldConfig) {
if (fieldConfig.containsKey(SPARSE_ENCODING_CONFIG) == false) {
return;
}
final Map<String, Object> sparseEncodingConfig = (Map<String, Object>) fieldConfig.get(SPARSE_ENCODING_CONFIG);
final PruneType pruneType = readPruneType(sparseEncodingConfig);
final Float pruneRatio = readPruneRatio(sparseEncodingConfig);
if (pruneType == null && pruneRatio == null) {
return;
}
this.pruneType = pruneType;
this.pruneRatio = pruneRatio;
}

private Float readPruneRatio(@NonNull final Map<String, Object> config) {
if (config.containsKey(PRUNE_RATIO_FIELD)) {
try {
return Float.parseFloat(config.get(PRUNE_RATIO_FIELD).toString());
} catch (Exception e) {
throw new MapperParsingException(String.format(Locale.ROOT, "[%s] must be a Float", PRUNE_RATIO_FIELD));
}
}
return null;
}

private Float consumePruneRatio(@NonNull final Map<String, Object> config) {
try {
return readPruneRatio(config);
} finally {
config.remove(PRUNE_RATIO_FIELD);
}
}

private PruneType readPruneType(@NonNull final Map<String, Object> config) {
if (config.containsKey(PRUNE_TYPE_FIELD)) {
try {
return PruneType.fromString((String) config.get(PRUNE_TYPE_FIELD));
} catch (Exception e) {
throw new MapperParsingException(
String.format(Locale.ROOT, "Invalid [%s]. Valid values are [%s].", PRUNE_TYPE_FIELD, PruneType.getValidValues())
);
}
}
return null;
}

private PruneType consumePruneType(@NonNull final Map<String, Object> config) {
try {
return readPruneType(config);
} finally {
config.remove(PRUNE_TYPE_FIELD);
}
}

public void toXContent(@NonNull final XContentBuilder builder, String name) throws IOException {
builder.startObject(name);
if (pruneType != null) {
builder.field(PRUNE_TYPE_FIELD, pruneType.getValue());
}
if (pruneRatio != null) {
builder.field(PRUNE_RATIO_FIELD, pruneRatio.floatValue());
}
builder.endObject();
}

@Override
public String toString() {
final Map<String, Object> config = new HashMap<>();
if (pruneType != null) {
config.put(PRUNE_TYPE_FIELD, pruneType.getValue());
}
if (pruneRatio != null) {
config.put(PRUNE_RATIO_FIELD, pruneRatio);
}
return config.toString();
}

@Override
public boolean equals(Object obj) {
if (this == obj) {
return true;
} else if (obj != null && this.getClass() == obj.getClass()) {
SparseEncodingConfig other = (SparseEncodingConfig) obj;
EqualsBuilder equalsBuilder = new EqualsBuilder();
equalsBuilder.append(this.pruneType, other.pruneType);
equalsBuilder.append(this.pruneRatio, other.pruneRatio);
return equalsBuilder.isEquals();
} else {
return false;
}
}

@Override
public int hashCode() {
return (new HashCodeBuilder()).append(this.pruneType).append(this.pruneRatio).toHashCode();
}
}
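To make the constructor's validation branches concrete, here is a short sketch of representative accepted and rejected inputs. It assumes "max_ratio" is a valid PruneType and 0.1 a legal ratio for it, and that PRUNE_TYPE_FIELD/PRUNE_RATIO_FIELD are accessible PruneUtils constants; none of that is shown in this diff.

import java.util.Map;

import org.opensearch.index.mapper.MapperParsingException;
import org.opensearch.neuralsearch.mapper.dto.SparseEncodingConfig;

import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD;
import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD;

public class SparseEncodingConfigValidationSketch {
    public static void main(String[] args) {
        // Accepted: prune_type and prune_ratio supplied together and consistent with each other.
        new SparseEncodingConfig(SPARSE_ENCODING_CONFIG, Map.of(PRUNE_TYPE_FIELD, "max_ratio", PRUNE_RATIO_FIELD, 0.1f));

        // Accepted: an empty config means nothing is configured; both fields stay null.
        new SparseEncodingConfig(SPARSE_ENCODING_CONFIG, Map.of());

        // Rejected: a prune ratio without a prune type (or with prune type "none").
        expectRejected(Map.of(PRUNE_RATIO_FIELD, 0.1f));

        // Rejected: a prune type other than "none" without a prune ratio.
        expectRejected(Map.of(PRUNE_TYPE_FIELD, "max_ratio"));

        // Rejected: unrecognized keys.
        expectRejected(Map.of("unknown_key", "value"));
    }

    private static void expectRejected(final Map<String, Object> config) {
        try {
            new SparseEncodingConfig(SPARSE_ENCODING_CONFIG, config);
            throw new AssertionError("expected a MapperParsingException");
        } catch (MapperParsingException expected) {
            // The constructor rejected the config, matching the validation cases commented above.
        }
    }
}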
@@ -24,6 +24,8 @@
import static org.opensearch.neuralsearch.constants.MappingConstants.PROPERTIES;
import static org.opensearch.neuralsearch.constants.MappingConstants.TYPE;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.DENSE_EMBEDDING_CONFIG;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_FIELD_SEARCH_ANALYZER;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.EMBEDDING_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_TEXT_FIELD_NAME;
@@ -51,6 +53,7 @@ public class SemanticInfoConfigBuilder {
private Boolean chunkingEnabled;
private String semanticFieldSearchAnalyzer;
private Map<String, Object> denseEmbeddingConfig;
private boolean sparseEncodingConfigDefined;
private final static List<String> UNSUPPORTED_DENSE_EMBEDDING_CONFIG = List.of(
KNN_VECTOR_DIMENSION_FIELD_NAME,
KNN_VECTOR_DATA_TYPE_FIELD_NAME,
@@ -111,24 +114,47 @@ public Map<String, Object> build() {
}

private void validate() {
if (semanticFieldSearchAnalyzer != null && RankFeaturesFieldMapper.CONTENT_TYPE.equals(embeddingFieldType) == false) {
if (KNNVectorFieldMapper.CONTENT_TYPE.equals(embeddingFieldType)) {
validateSearchAnalyzerNotDefined();
validateSparseEncodingConfigNotDefined();
}

if (RankFeaturesFieldMapper.CONTENT_TYPE.equals(embeddingFieldType)) {
validateDenseEmbeddingConfigNotDefined();
}
}

private void validateSparseEncodingConfigNotDefined() {
if (sparseEncodingConfigDefined) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Cannot build the semantic info config because the embedding field type %s cannot build with semantic field search analyzer %s",
embeddingFieldType,
semanticFieldSearchAnalyzer
"Cannot build the semantic info config because the dense(text embedding) model cannot support %s.",
SPARSE_ENCODING_CONFIG
)
);
}
}

private void validateDenseEmbeddingConfigNotDefined() {
if (denseEmbeddingConfig != null && RankFeaturesFieldMapper.CONTENT_TYPE.equals(embeddingFieldType)) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Cannot build the semantic info config because %s is not supported by %s.",
DENSE_EMBEDDING_CONFIG,
embeddingFieldType
"Cannot build the semantic info config because %s is not supported by the sparse model.",
DENSE_EMBEDDING_CONFIG
)
);
}
}

private void validateSearchAnalyzerNotDefined() {
if (semanticFieldSearchAnalyzer != null) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
"Cannot build the semantic info config because the dense(text embedding) model cannot support %s",
SEMANTIC_FIELD_SEARCH_ANALYZER
)
);
}
@@ -343,4 +369,9 @@ public SemanticInfoConfigBuilder denseEmbeddingConfig(final Map<String, Object>
this.denseEmbeddingConfig = denseEmbeddingConfig;
return this;
}

public SemanticInfoConfigBuilder sparseEncodingConfigDefined(final boolean sparseEncodingConfigDefined) {
this.sparseEncodingConfigDefined = sparseEncodingConfigDefined;
return this;
}
}
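For the dense-versus-sparse checks above, the sketch below illustrates a mapping the builder would now reject at index-creation time: the same sparse_encoding_config block as in the earlier sketch, but attached to a semantic field whose model is a dense text embedding model (so the auto-generated embedding field is a knn_vector), which trips validateSparseEncodingConfigNotDefined(). The field name, model id, and key names are placeholders and assumptions, as before.

public class RejectedSparseEncodingConfigSketch {
    // Illustrative only: with a dense (text embedding) model behind the semantic field, index creation
    // fails with an IllegalArgumentException saying the dense model cannot support sparse_encoding_config.
    public static final String REJECTED_MAPPING = """
        {
          "properties": {
            "passage_text": {
              "type": "semantic",
              "model_id": "<dense-text-embedding-model-id>",
              "sparse_encoding_config": { "prune_type": "max_ratio", "prune_ratio": 0.1 }
            }
          }
        }
        """;
}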
@@ -25,6 +25,7 @@

import static org.opensearch.neuralsearch.constants.MappingConstants.PROPERTIES;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_INFO_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SPARSE_ENCODING_CONFIG;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.collectSemanticField;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.extractModelIdToFieldPathMap;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.getDenseEmbeddingConfig;
@@ -221,6 +222,7 @@ private Map<String, Object> createSemanticInfoField(
builder.chunkingEnabled(isChunkingEnabled(fieldConfig, fieldPath));
builder.semanticFieldSearchAnalyzer(getSemanticFieldSearchAnalyzer(fieldConfig, fieldPath));
builder.denseEmbeddingConfig(getDenseEmbeddingConfig(fieldConfig, fieldPath));
builder.sparseEncodingConfigDefined(fieldConfig.containsKey(SPARSE_ENCODING_CONFIG));
return builder.build();
}
