Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
- Implement analyzer based neural sparse query ([#1088](https://github.com/opensearch-project/neural-search/pull/1088) [#1279](https://github.com/opensearch-project/neural-search/pull/1279))
- [Semantic Field] Add semantic mapping transformer. ([#1276](https://github.com/opensearch-project/neural-search/pull/1276))
- [Semantic Field] Add semantic ingest processor. ([#1309](https://github.com/opensearch-project/neural-search/pull/1309))

### Enhancements

Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ dependencies {
testFixturesImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
testImplementation fileTree(dir: knnJarDirectory, include: ["opensearch-knn-${opensearch_build}.jar", "remote-index-build-client-${opensearch_build}.jar"])
testImplementation "org.opensearch.plugin:parent-join-client:${opensearch_version}"
testImplementation 'org.assertj:assertj-core:3.24.2'
}

// In order to add the jar to the classpath, we need to unzip the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,9 @@ public class MappingConstants {
* Name for properties. An object field will define subfields as properties.
*/
public static final String PROPERTIES = "properties";

/**
* Separator in a field path.
*/
public static final String PATH_SEPARATOR = ".";
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,15 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import static org.opensearch.neuralsearch.constants.MappingConstants.PROPERTIES;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.MODEL_ID;
import static org.opensearch.neuralsearch.constants.SemanticFieldConstants.SEMANTIC_INFO_FIELD_NAME;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.collectSemanticField;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.extractModelIdToFieldPathMap;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.getProperties;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.validateModelId;
import static org.opensearch.neuralsearch.util.SemanticMappingUtils.validateSemanticInfoFieldName;

/**
* SemanticMappingTransformer transforms the index mapping for the semantic field to auto add the semantic info fields
Expand Down Expand Up @@ -168,101 +166,17 @@ private void validateSemanticFields(@NonNull final Map<String, Map<String, Objec
}
}

/**
 * Validates the model id configured for a semantic field.
 *
 * @param semanticFieldPath   path of the semantic field, used only to build the error message
 * @param semanticFieldConfig configuration map of the semantic field
 * @return an error message when the model id is missing, not a string, or empty; {@code null} when it is valid
 */
private String validateModelId(@NonNull final String semanticFieldPath, @NonNull final Map<String, Object> semanticFieldConfig) {
    final Object configuredModelId = semanticFieldConfig.get(SemanticFieldConstants.MODEL_ID);

    final String errorTemplate;
    if (configuredModelId == null) {
        errorTemplate = "%s is required for the semantic field at %s";
    } else if (configuredModelId instanceof String modelIdStr && modelIdStr.isEmpty() == false) {
        // A non-empty string is the only valid shape; null signals "no error".
        return null;
    } else {
        errorTemplate = "%s should be a non-empty string for the semantic field at %s";
    }

    return String.format(Locale.ROOT, errorTemplate, MODEL_ID, semanticFieldPath);
}

/**
 * Validates the optional custom semantic info field name of a semantic field.
 *
 * @param semanticFieldPath   path of the semantic field, used only to build the error message
 * @param semanticFieldConfig configuration map of the semantic field
 * @return an error message when the configured name is not a string, is empty, or contains '.';
 *         {@code null} when the name is absent or valid
 */
private String validateSemanticInfoFieldName(
    @NonNull final String semanticFieldPath,
    @NonNull final Map<String, Object> semanticFieldConfig
) {
    // SEMANTIC_INFO_FIELD_NAME is optional: when the key is absent there is nothing to validate.
    if (semanticFieldConfig.containsKey(SEMANTIC_INFO_FIELD_NAME) == false) {
        return null;
    }

    final Object configuredName = semanticFieldConfig.get(SEMANTIC_INFO_FIELD_NAME);
    final String errorTemplate;
    if (configuredName instanceof String nameStr) {
        if (nameStr.isEmpty()) {
            errorTemplate = "%s cannot be an empty string for the semantic field at %s";
        } else if (nameStr.contains(".")) {
            // OpenSearch allows a field name with "." in the index mapping and unflattens it later, but
            // supporting that for the custom semantic info field name is not necessary, so it is blocked.
            errorTemplate = "%s should not contain '.' for the semantic field at %s";
        } else {
            return null;
        }
    } else {
        // Key present but the value is null or not a string.
        errorTemplate = "%s should be a non-empty string for the semantic field at %s";
    }

    return String.format(Locale.ROOT, errorTemplate, SEMANTIC_INFO_FIELD_NAME, semanticFieldPath);
}

/**
 * Fetches the model config for every model id referenced by the semantic fields and then modifies the
 * mappings based on the fetched configs.
 *
 * The listener is completed exactly once: with {@code null} when there are no model ids to fetch or
 * after the mappings have been modified, or with a failure reported by {@code getModels}.
 *
 * @param semanticFieldPathToConfigMap semantic field path to its field config
 * @param mappings                     index mappings to modify
 * @param listener                     listener notified when the work is done or has failed
 */
private void fetchModelAndModifyMapping(
    @NonNull final Map<String, Map<String, Object>> semanticFieldPathToConfigMap,
    @NonNull final Map<String, Object> mappings,
    @NonNull final ActionListener<Void> listener
) {
    final Map<String, List<String>> modelIdToFieldPathMap = extractModelIdToFieldPathMap(semanticFieldPathToConfigMap);
    if (modelIdToFieldPathMap.isEmpty()) {
        listener.onResponse(null);
        // Must stop here: getModels invokes its success callback for an empty id set as well, which
        // would complete the listener a second time.
        return;
    }

    // There can be multiple semantic fields with different model ids. getModels fetches the config for
    // each of them and reports a single aggregated success or failure, so no manual fan-out is needed.
    mlClientAccessor.getModels(modelIdToFieldPathMap.keySet(), modelIdToConfigMap -> {
        modifyMappings(modelIdToConfigMap, mappings, modelIdToFieldPathMap, semanticFieldPathToConfigMap);
        listener.onResponse(null);
    }, listener::onFailure);
}

private void modifyMappings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -316,6 +321,71 @@ public void getModel(@NonNull final String modelId, @NonNull final ActionListene
retryableGetModel(modelId, 0, listener);
}

/**
 * Get model info for multiple model ids by issuing one getModel request per id.
 * The call fails if any one of the get model requests fails; the success consumer is only invoked
 * when the info of every model has been retrieved. An exception thrown by the success consumer is
 * forwarded to the failure consumer.
 * @param modelIds a set of model ids
 * @param onSuccess consumer of the model id to model map
 * @param onFailure consumer of the failure
 */
public void getModels(
    @NonNull final Set<String> modelIds,
    @NonNull final Consumer<Map<String, MLModel>> onSuccess,
    @NonNull final Consumer<Exception> onFailure
) {
    // Nothing to fetch: report an empty result right away.
    if (modelIds.isEmpty()) {
        try {
            onSuccess.accept(Collections.emptyMap());
        } catch (Exception consumerFailure) {
            onFailure.accept(consumerFailure);
        }
        return;
    }

    // Shared state for the per-model callbacks; the outcome is reported by whichever callback
    // brings the outstanding-request counter down to zero.
    final AtomicInteger outstanding = new AtomicInteger(modelIds.size());
    final AtomicBoolean anyFailed = new AtomicBoolean(false);
    final List<String> failureMessages = Collections.synchronizedList(new ArrayList<>());
    final Map<String, MLModel> fetchedModels = new ConcurrentHashMap<>();

    for (String modelId : modelIds) {
        try {
            getModel(modelId, ActionListener.wrap(model -> {
                fetchedModels.put(modelId, model);
                if (outstanding.decrementAndGet() != 0) {
                    // Other requests are still pending.
                    return;
                }
                if (anyFailed.get()) {
                    onFailure.accept(new RuntimeException(String.join(";", failureMessages)));
                    return;
                }
                try {
                    onSuccess.accept(fetchedModels);
                } catch (Exception consumerFailure) {
                    onFailure.accept(consumerFailure);
                }
            }, fetchFailure -> handleGetModelException(anyFailed, failureMessages, modelId, fetchFailure, outstanding, onFailure)));
        } catch (Exception submitFailure) {
            // getModel itself threw; account for it like a failed response.
            handleGetModelException(anyFailed, failureMessages, modelId, submitFailure, outstanding, onFailure);
        }
    }

}

/**
 * Records a failed model fetch and, when it was the last outstanding request, reports the
 * aggregated failure.
 *
 * @param hasError  flag set to true to mark that at least one fetch failed
 * @param errors    collected error messages; one entry is appended for this failure
 * @param modelId   id of the model whose fetch failed
 * @param e         the failure
 * @param counter   number of requests still outstanding; decremented by one here
 * @param onFailure consumer invoked with all collected messages joined by ";" once the counter reaches zero
 */
private void handleGetModelException(
    AtomicBoolean hasError,
    List<String> errors,
    String modelId,
    Exception e,
    AtomicInteger counter,
    @NonNull Consumer<Exception> onFailure
) {
    // Flag and message are recorded before the counter is decremented, so whichever callback
    // observes zero also observes this failure.
    hasError.set(true);
    final String failureMessage = "Failed to fetch model [" + modelId + "]: " + e.getMessage();
    errors.add(failureMessage);

    final boolean allRequestsDone = counter.decrementAndGet() == 0;
    if (allRequestsDone == false) {
        return;
    }
    onFailure.accept(new RuntimeException(String.join(";", errors)));
}

private void retryableGetModel(@NonNull final String modelId, final int retryTime, @NonNull final ActionListener<MLModel> listener) {
mlClient.getModel(
modelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.opensearch.index.mapper.MappingTransformer;
import org.opensearch.neuralsearch.mapper.SemanticFieldMapper;
import org.opensearch.neuralsearch.mappingtransformer.SemanticMappingTransformer;
import org.opensearch.neuralsearch.processor.factory.SemanticFieldProcessorFactory;
import org.opensearch.plugins.MapperPlugin;
import org.opensearch.transport.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down Expand Up @@ -300,4 +301,17 @@ public Map<String, Mapper.TypeParser> getMappers() {
public List<MappingTransformer> getMappingTransformers() {
return List.of(new SemanticMappingTransformer(clientAccessor, xContentRegistry));
}

/**
 * Provides the system ingest processor factories of this plugin: a single
 * {@link SemanticFieldProcessorFactory} keyed by its processor factory type.
 *
 * @param parameters processor parameters supplying the environment, cluster service and analysis registry
 * @return map from processor factory type to the factory
 */
@Override
public Map<String, Processor.Factory> getSystemIngestProcessors(Processor.Parameters parameters) {
    final SemanticFieldProcessorFactory semanticFieldProcessorFactory = new SemanticFieldProcessorFactory(
        clientAccessor,
        parameters.env,
        parameters.ingestService.getClusterService(),
        parameters.analysisRegistry
    );
    return Map.of(SemanticFieldProcessorFactory.PROCESSOR_FACTORY_TYPE, semanticFieldProcessorFactory);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.dto;

import lombok.Data;

import java.util.List;

import static org.opensearch.neuralsearch.constants.MappingConstants.PATH_SEPARATOR;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_EMBEDDING_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.CHUNKS_FIELD_NAME;
import static org.opensearch.neuralsearch.constants.SemanticInfoFieldConstants.MODEL_FIELD_NAME;

/**
 * SemanticFieldInfo is a data transfer object holding the info of one semantic field, plus helpers
 * that derive the paths of its semantic info sub-fields.
 */
@Data
public class SemanticFieldInfo {
    /**
     * The raw string value of the semantic field
     */
    private String value;
    /**
     * The model id of the semantic field which will be used to generate the embedding
     */
    private String modelId;
    /**
     * The full path to the semantic field
     */
    private String fullPath;
    /**
     * The full path to the semantic info fields
     */
    private String semanticInfoFullPath;
    /**
     * The chunked strings of the original string value of the semantic field
     */
    private List<String> chunks;

    /**
     * @return full path to the chunks field of the semantic field
     */
    public String getFullPathForChunks() {
        return semanticInfoFullPath + PATH_SEPARATOR + CHUNKS_FIELD_NAME;
    }

    /**
     * @param index index of the chunk the embedding is in
     * @return full path to the embedding field of the chunk at the given index
     */
    public String getFullPathForEmbedding(int index) {
        // <semantic info path>.<chunks field>.<index>.<embedding field>
        return semanticInfoFullPath + PATH_SEPARATOR + CHUNKS_FIELD_NAME + PATH_SEPARATOR + index + PATH_SEPARATOR
            + CHUNKS_EMBEDDING_FIELD_NAME;
    }

    /**
     * @return full path to the model info fields
     */
    public String getFullPathForModelInfo() {
        return semanticInfoFullPath + PATH_SEPARATOR + MODEL_FIELD_NAME;
    }
}
Loading
Loading