1414import java .util .function .BiConsumer ;
1515
1616import org .apache .commons .lang3 .StringUtils ;
17+ import org .opensearch .action .get .GetAction ;
18+ import org .opensearch .action .get .GetRequest ;
19+ import org .opensearch .action .get .GetResponse ;
1720import org .opensearch .cluster .service .ClusterService ;
1821import org .opensearch .core .action .ActionListener ;
1922import org .opensearch .env .Environment ;
2427import com .google .common .annotations .VisibleForTesting ;
2528
2629import lombok .extern .log4j .Log4j2 ;
30+ import org .opensearch .neuralsearch .processor .optimization .TextImageEmbeddingInferenceFilter ;
31+ import org .opensearch .transport .client .OpenSearchClient ;
2732
2833/**
2934 * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use,
@@ -35,19 +40,24 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor {
// Processor type name; this is the value returned by getType().
3540 public static final String TYPE = "text_image_embedding" ;
3641 public static final String MODEL_ID_FIELD = "model_id" ;
3742 public static final String EMBEDDING_FIELD = "embedding" ;
// Name and default value of the skip_existing option.
// NOTE(review): presumably read by the processor factory when parsing the processor config - the factory is not visible here, confirm.
43+ public static final boolean DEFAULT_SKIP_EXISTING = false ;
44+ public static final String SKIP_EXISTING = "skip_existing" ;
3845 public static final String FIELD_MAP_FIELD = "field_map" ;
3946 public static final String TEXT_FIELD_NAME = "text" ;
4047 public static final String IMAGE_FIELD_NAME = "image" ;
4148 public static final String INPUT_TEXT = "inputText" ;
4249 public static final String INPUT_IMAGE = "inputImage" ;
// Keys looked up in ingestDocument.getSourceAndMetadata() by execute() to build the GetRequest for the stored document.
50+ private static final String INDEX_FIELD = "_index" ;
51+ private static final String ID_FIELD = "_id" ;
// The two recognised field names ("text" and "image").
// NOTE(review): presumably used to validate field_map keys - the validation code is not visible here, confirm.
4352 private static final Set <String > VALID_FIELD_NAMES = Set .of (TEXT_FIELD_NAME , IMAGE_FIELD_NAME );
4453
// Model id placed on every MapInferenceRequest sent through mlCommonsClientAccessor.
4554 private final String modelId ;
// Passed to inferenceFilter.filterAndCopyExistingEmbeddings(...) when skipExisting is enabled.
4655 private final String embedding ;
4756 private final Map <String , String > fieldMap ;
48-
// When false, execute() always runs inference; when true, execute() first fetches the stored document
// (if both _index and _id are present) and hands it to inferenceFilter before deciding what to infer.
57+ private final boolean skipExisting ;
// Client used by execute() to run GetAction against the document's index and id.
58+ private final OpenSearchClient openSearchClient ;
4959 private final MLCommonsClientAccessor mlCommonsClientAccessor ;
50-
// Filters the knn map against the stored document when skipExisting is enabled.
60+ private final TextImageEmbeddingInferenceFilter inferenceFilter ;
5161 private final Environment environment ;
5262 private final ClusterService clusterService ;
5363
@@ -57,6 +67,9 @@ public TextImageEmbeddingProcessor(
5767 final String modelId ,
5868 final String embedding ,
5969 final Map <String , String > fieldMap ,
70+ final boolean skipExisting ,
71+ final TextImageEmbeddingInferenceFilter inferenceFilter ,
72+ final OpenSearchClient openSearchClient ,
6073 final MLCommonsClientAccessor clientAccessor ,
6174 final Environment environment ,
6275 final ClusterService clusterService
@@ -71,6 +84,9 @@ public TextImageEmbeddingProcessor(
7184 this .mlCommonsClientAccessor = clientAccessor ;
7285 this .environment = environment ;
7386 this .clusterService = clusterService ;
87+ this .skipExisting = skipExisting ;
88+ this .inferenceFilter = inferenceFilter ;
89+ this .openSearchClient = openSearchClient ;
7490 }
7591
7692 private void validateEmbeddingConfiguration (final Map <String , String > fieldMap ) {
@@ -109,15 +125,28 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer<Ingest
109125 Map <String , String > inferenceMap = createInferences (knnMap );
110126 if (inferenceMap .isEmpty ()) {
111127 handler .accept (ingestDocument , null );
112- } else {
113- mlCommonsClientAccessor .inferenceSentencesMap (
114- MapInferenceRequest .builder ().modelId (this .modelId ).inputObjects (inferenceMap ).build (),
115- ActionListener .wrap (vectors -> {
116- setVectorFieldsToDocument (ingestDocument , vectors );
117- handler .accept (ingestDocument , null );
118- }, e -> { handler .accept (null , e ); })
119- );
128+ return ;
129+ }
130+ if (skipExisting == false ) {
131+ generateAndSetInference (ingestDocument , inferenceMap , handler );
132+ return ;
120133 }
134+ // if skipExisting flag is turned on, eligible inference text and images will be compared and filtered after embeddings are
135+ // copied
136+ Object index = ingestDocument .getSourceAndMetadata ().get (INDEX_FIELD );
137+ Object id = ingestDocument .getSourceAndMetadata ().get (ID_FIELD );
138+ if (Objects .isNull (index ) || Objects .isNull (id )) {
139+ generateAndSetInference (ingestDocument , inferenceMap , handler );
140+ return ;
141+ }
142+ openSearchClient .execute (
143+ GetAction .INSTANCE ,
144+ new GetRequest (index .toString (), id .toString ()),
145+ ActionListener .wrap (
146+ response -> reuseOrGenerateEmbedding (response , ingestDocument , knnMap , inferenceMap , handler ),
147+ e -> handler .accept (null , e )
148+ )
149+ );
121150 } catch (Exception e ) {
122151 handler .accept (null , e );
123152 }
@@ -174,4 +203,55 @@ Map<String, Object> buildTextEmbeddingResult(final String knnKey, List<Number> m
// Returns this processor's type name, i.e. the TYPE constant ("text_image_embedding").
174203 public String getType () {
175204 return TYPE ;
176205 }
206+
207+ /**
208+ * This method invokes inference call through mlCommonsClientAccessor and populates retrieved embeddings to ingestDocument
209+ *
210+ * @param ingestDocument ingestDocument to populate embeddings to
211+ * @param inferenceMap map indicating the path in ingestDocument to populate embeddings
212+ * @param handler SourceAndMetadataMap of ingestDocument Document
213+ *
214+ */
215+ private void generateAndSetInference (
216+ IngestDocument ingestDocument ,
217+ Map <String , String > inferenceMap ,
218+ BiConsumer <IngestDocument , Exception > handler
219+ ) {
220+ mlCommonsClientAccessor .inferenceSentencesMap (
221+ MapInferenceRequest .builder ().modelId (this .modelId ).inputObjects (inferenceMap ).build (),
222+ ActionListener .wrap (vectors -> {
223+ setVectorFieldsToDocument (ingestDocument , vectors );
224+ handler .accept (ingestDocument , null );
225+ }, e -> { handler .accept (null , e ); })
226+ );
227+ }
228+
229+ // This method validates and filters given knnMap and inferenceMap after response is successfully retrieved from get operation.
230+ private void reuseOrGenerateEmbedding (
231+ GetResponse response ,
232+ IngestDocument ingestDocument ,
233+ Map <String , String > knnMap ,
234+ Map <String , String > inferenceMap ,
235+ BiConsumer <IngestDocument , Exception > handler
236+ ) {
237+ final Map <String , Object > existingDocument = response .getSourceAsMap ();
238+ if (existingDocument == null || existingDocument .isEmpty ()) {
239+ generateAndSetInference (ingestDocument , inferenceMap , handler );
240+ return ;
241+ }
242+ // filter given knnMap by comparing existing document with ingestDocument
243+ Map <String , String > filteredKnnMap = inferenceFilter .filterAndCopyExistingEmbeddings (
244+ ingestDocument ,
245+ existingDocument ,
246+ knnMap ,
247+ embedding
248+ );
249+ // create inference map based on filtered knnMap
250+ Map <String , String > filteredInferenceMap = createInferences (filteredKnnMap );
251+ if (filteredInferenceMap .isEmpty ()) {
252+ handler .accept (ingestDocument , null );
253+ } else {
254+ generateAndSetInference (ingestDocument , filteredInferenceMap , handler );
255+ }
256+ }
177257}
0 commit comments