From 6df1c07c12824ab69c890e9eb7c833922beef576 Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Tue, 28 Oct 2025 21:27:31 +0800 Subject: [PATCH 1/6] enhance multimodal embeddings --- .../embedding_transform_multimodal.conf | 78 ++++++ .../embedding/EmbeddingTransform.java | 95 ++++--- .../nlpmodel/embedding/SrcField.java | 50 ++++ .../{FieldSpec.java => SrcFieldSpec.java} | 46 +-- .../nlpmodel/embedding/VectorFieldSpec.java | 101 +++++++ .../multimodal/MultimodalFieldValue.java | 32 +-- .../embedding/remote/doubao/DoubaoModel.java | 75 ++--- .../embedding/DoubaoMultimodalModelTest.java | 265 ++++++++++++------ .../transform/embedding/FieldSpecTest.java | 114 -------- .../embedding/VectorFieldSpecTest.java | 200 +++++++++++++ 10 files changed, 726 insertions(+), 330 deletions(-) create mode 100644 seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java rename seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/{FieldSpec.java => SrcFieldSpec.java} (64%) create mode 100644 seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/VectorFieldSpec.java delete mode 100644 seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java create mode 100644 seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java diff --git a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf index efc72731457..cada8b7ca27 100644 --- a/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf +++ b/seatunnel-e2e/seatunnel-transforms-v2-e2e/seatunnel-transforms-v2-e2e-part-1/src/test/resources/embedding_transform_multimodal.conf @@ -154,6 +154,48 @@ transform { } product_name_vector = product_name + + multi_field_text_vector = [product_name, description] + + multi_field_image_vector = [ + { + field = product_image_url + modality = jpeg + format = url + }, + { + field = thumbnail_image + modality = png + format = url + } + ] + + multi_field_video_vector = [ + { + field = product_video_url + modality = mp4 + format = url + }, + { + field = promotional_video + modality = mov + format = url + } + ] + + multi_field_mix_vector = [ + product_name, + { + field = product_image_url + modality = jpeg + format = url + }, + { + field = product_video_url + modality = mp4 + format = url + } + ] } plugin_output = "multimodal_embedding_output" @@ -219,6 +261,42 @@ sink { } ] }, + { + field_name = multi_field_text_vector + field_type = float_vector + field_value = [ + { + rule_type = NOT_NULL + } + ] + }, + { + field_name = multi_field_image_vector + field_type = float_vector + field_value = [ + { + rule_type = NOT_NULL + } + ] + }, + { + field_name = multi_field_video_vector + field_type = float_vector + field_value = [ + { + rule_type = NOT_NULL + } + ] + }, + { + field_name = multi_field_mix_vector + field_type = float_vector + field_value = [ + { + rule_type = NOT_NULL + } + ] + }, { field_name = category field_type = string diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java index 8857e6264c6..0eab70b5f8a 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/EmbeddingTransform.java @@ -57,16 +57,16 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; @Slf4j public class EmbeddingTransform extends MultipleFieldOutputTransform { private final ReadonlyConfig config; - private List fieldOriginalIndexes; private transient Model model; private Integer dimension; private boolean isMultimodalFields = false; - private Map fieldSpecMap; + private Map> fieldSpecMap; private List fieldNames; private final Map> binaryFileCache = new ConcurrentHashMap<>(); @@ -197,29 +197,33 @@ public void open() { } private void initOutputFields(SeaTunnelRowType inputRowType, ReadonlyConfig config) { - Map fieldSpecMap = new HashMap<>(); - List fieldNames = new ArrayList<>(); Map fieldsConfig = config.get(EmbeddingTransformConfig.VECTORIZATION_FIELDS); if (fieldsConfig == null || fieldsConfig.isEmpty()) { throw new IllegalArgumentException("vectorization_fields configuration is required"); } - for (Map.Entry field : fieldsConfig.entrySet()) { - FieldSpec fieldSpec = new FieldSpec(field); - log.info("Field spec: {}", fieldSpec.toString()); - String srcField = fieldSpec.getFieldName(); - int srcFieldIndex; - try { - srcFieldIndex = inputRowType.indexOf(srcField); - } catch (IllegalArgumentException e) { - throw TransformCommonError.cannotFindInputFieldError(getPluginName(), srcField); - } - if (fieldSpec.isMultimodalField()) { - isMultimodalFields = true; + List fieldNames = new ArrayList<>(); + Map> fieldSpecMap = new HashMap<>(); + for (Map.Entry fieldConfig : fieldsConfig.entrySet()) { + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(fieldConfig); + log.info("Vector field spec: {}", vectorFieldSpec); + List srcFieldNames = + vectorFieldSpec.getSrcFieldSpecs().stream() + .map(SrcFieldSpec::getFieldName) + .collect(Collectors.toList()); + List srcFieldIndexes = new ArrayList<>(); + for (String srcFieldName : srcFieldNames) { + try { + srcFieldIndexes.add(inputRowType.indexOf(srcFieldName)); + } catch (IllegalArgumentException e) { + throw TransformCommonError.cannotFindInputFieldsError( + getPluginName(), srcFieldNames); + } } - fieldSpecMap.put(srcFieldIndex, fieldSpec); - fieldNames.add(field.getKey()); + isMultimodalFields = vectorFieldSpec.isMultimodalField(); + fieldSpecMap.put(vectorFieldSpec, srcFieldIndexes); + fieldNames.add(vectorFieldSpec.getFieldName()); } this.fieldSpecMap = fieldSpecMap; this.fieldNames = fieldNames; @@ -232,19 +236,28 @@ protected Object[] getOutputFieldValues(SeaTunnelRowAccessor inputRow) { if (MetadataUtil.isBinaryFormat(inputRow)) { return vectorizationBinaryRow(inputRow); } - Set fieldOriginalIndexes = fieldSpecMap.keySet(); - Object[] fieldValues = new Object[fieldOriginalIndexes.size()]; - List vectorization; + + Set vectorFieldSpecs = fieldSpecMap.keySet(); + Object[] fieldValues = new Object[vectorFieldSpecs.size()]; int i = 0; - for (Integer fieldOriginalIndex : fieldOriginalIndexes) { - FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex); - Object value = inputRow.getField(fieldOriginalIndex); + for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) { + List srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs(); + List srcFieldIndexes = fieldSpecMap.get(vectorFieldSpec); + List srcFields = new ArrayList<>(); + for (int j = 0; j < srcFieldSpecs.size(); j++) { + srcFields.add( + new SrcField( + srcFieldSpecs.get(j), + inputRow.getField(srcFieldIndexes.get(j)))); + } fieldValues[i++] = - isMultimodalFields ? new MultimodalFieldValue(fieldSpec, value) : value; + isMultimodalFields + ? new MultimodalFieldValue(srcFields) + : srcFields.get(0).getFieldValue(); } - vectorization = model.vectorization(fieldValues); + List vectorization = model.vectorization(fieldValues); return vectorization.toArray(); } catch (Exception e) { throw new RuntimeException("Failed to data vectorization", e); @@ -282,32 +295,34 @@ public boolean isMultimodalFields() { /** Process a row in binary format: [data, relativePath, partIndex] */ private Object[] vectorizationBinaryRow(SeaTunnelRowAccessor inputRow) throws Exception { - byte[] completeData = processBinaryRow(inputRow); if (completeData == null) { return null; } - Set fieldOriginalIndexes = fieldSpecMap.keySet(); - Object[] fieldValues = new Object[fieldOriginalIndexes.size()]; + + Set vectorFieldSpecs = fieldSpecMap.keySet(); + Object[] fieldValues = new Object[vectorFieldSpecs.size()]; int i = 0; - for (Integer fieldOriginalIndex : fieldOriginalIndexes) { - FieldSpec fieldSpec = fieldSpecMap.get(fieldOriginalIndex); - if (fieldSpec.isBinary()) { - fieldValues[i++] = new MultimodalFieldValue(fieldSpec, completeData); - } else { - log.warn( - "Non-binary field {} configured in binary format data", - fieldSpec.getFieldName()); - fieldValues[i++] = null; + for (VectorFieldSpec vectorFieldSpec : vectorFieldSpecs) { + List srcFieldSpecs = vectorFieldSpec.getSrcFieldSpecs(); + List srcFields = new ArrayList<>(); + for (SrcFieldSpec srcFieldSpec : srcFieldSpecs) { + if (srcFieldSpec.isBinary()) { + srcFields.add(new SrcField(srcFieldSpec, completeData)); + } else { + log.warn( + "Non-binary field {} configured in binary format data", + srcFieldSpec.getFieldName()); + } } + fieldValues[i++] = srcFields.isEmpty() ? null : new MultimodalFieldValue(srcFields); } try { return model.vectorization(fieldValues).toArray(); } catch (Exception e) { - throw new RuntimeException( - "Failed to vectorize binary data for file: " + inputRow.toString(), e); + throw new RuntimeException("Failed to vectorize binary data for file: " + inputRow, e); } } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java new file mode 100644 index 00000000000..c1fbcdd2b1d --- /dev/null +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.seatunnel.transform.nlpmodel.embedding; + +import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat; + +import lombok.Data; + +import java.io.Serializable; +import java.util.Base64; + +@Data +public class SrcField implements Serializable { + + private static final long serialVersionUID = 1L; + + private SrcFieldSpec fieldSpec; + + private Object fieldValue; + + public SrcField(SrcFieldSpec spec, Object value) { + this.fieldSpec = spec; + this.fieldValue = value; + } + + public String toBase64() { + if (fieldSpec == null || !PayloadFormat.BINARY.equals(fieldSpec.getPayloadFormat())) { + throw new IllegalArgumentException("Payload format must be binary"); + } + if (fieldValue == null) { + throw new IllegalArgumentException("Binary data cannot be null or empty"); + } + return Base64.getEncoder().encodeToString(fieldValue.toString().getBytes()); + } +} diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcFieldSpec.java similarity index 64% rename from seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java rename to seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcFieldSpec.java index 94ee65329eb..b0078ed2736 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/FieldSpec.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcFieldSpec.java @@ -26,7 +26,7 @@ import java.util.Map; @Data -public class FieldSpec implements Serializable { +public class SrcFieldSpec implements Serializable { private static final long serialVersionUID = 1L; @@ -34,40 +34,12 @@ public class FieldSpec implements Serializable { private ModalityType modalityType; private PayloadFormat payloadFormat; - public FieldSpec(String fieldName) { - this.fieldName = fieldName; - this.modalityType = ModalityType.TEXT; - this.payloadFormat = PayloadFormat.TEXT; - } - - public FieldSpec(Map.Entry fieldConfig) { - String outputFieldName = fieldConfig.getKey(); - if (outputFieldName == null) { - throw new IllegalArgumentException("Field spec cannot be null"); - } - Object fieldValue = fieldConfig.getValue(); - try { - if (fieldValue instanceof String) { - parseBasicFieldSpec((String) fieldValue); - } else { - Map fieldSpecConfig = (Map) fieldValue; - parseMultimodalFieldSpec(fieldSpecConfig); - } - } catch (Exception e) { - String errorMessage = - String.format( - "Invalid field spec for output field '%s': %s", - outputFieldName, fieldConfig); - throw new IllegalArgumentException(errorMessage, e); - } - } - /** Parse basic field spec: just the field name, defaults to TEXT modality and default format */ - private void parseBasicFieldSpec(String fieldSpec) { - if (fieldSpec == null || fieldSpec.trim().isEmpty()) { - throw new IllegalArgumentException("Field spec cannot be null or empty"); + public SrcFieldSpec(String fieldName) { + if (fieldName == null || fieldName.trim().isEmpty()) { + throw new IllegalArgumentException("Field name cannot be null or empty"); } - this.fieldName = fieldSpec.trim(); + this.fieldName = fieldName.trim(); this.modalityType = ModalityType.TEXT; this.payloadFormat = PayloadFormat.TEXT; } @@ -76,9 +48,9 @@ private void parseBasicFieldSpec(String fieldSpec) { * Parse multimodal field spec: field name, modality, and format Supports both formats: 1. * Separate modality and format */ - private void parseMultimodalFieldSpec(Map fieldConfig) { + public SrcFieldSpec(Map fieldConfig) { if (fieldConfig == null || fieldConfig.isEmpty()) { - throw new IllegalArgumentException("Field configuration cannot be null or empty"); + throw new IllegalArgumentException("Field config cannot be null or empty"); } Object fieldNameObj = fieldConfig.get("field"); @@ -109,10 +81,6 @@ private void parseMultimodalFieldSpec(Map fieldConfig) { } } - public boolean isMultimodalField() { - return !ModalityType.TEXT.equals(modalityType); - } - public boolean isBinary() { return PayloadFormat.BINARY.equals(payloadFormat); } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/VectorFieldSpec.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/VectorFieldSpec.java new file mode 100644 index 00000000000..bf8bd937f0c --- /dev/null +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/VectorFieldSpec.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.seatunnel.transform.nlpmodel.embedding; + +import org.apache.seatunnel.shade.org.apache.commons.lang3.StringUtils; + +import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType; + +import lombok.Data; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +@Data +public class VectorFieldSpec implements Serializable { + + private static final long serialVersionUID = 1L; + + private String fieldName; + + private List srcFieldSpecs; + + public VectorFieldSpec(Map.Entry fieldConfig) { + this.fieldName = fieldConfig.getKey(); + if (StringUtils.isBlank(fieldName)) { + throw new IllegalArgumentException("Field config name cannot be null or empty"); + } + Object fieldConfigValue = fieldConfig.getValue(); + if (fieldConfigValue == null) { + throw new IllegalArgumentException( + "Field config value cannot be null for field: " + fieldName); + } + + srcFieldSpecs = new ArrayList<>(); + try { + if (fieldConfigValue instanceof String) { + srcFieldSpecs.add(new SrcFieldSpec((String) fieldConfigValue)); + } else if (fieldConfigValue instanceof Map) { + srcFieldSpecs.add(new SrcFieldSpec((Map) fieldConfigValue)); + } else { + List fieldConfigValues = (List) fieldConfigValue; + for (Object fieldConfigValueItem : fieldConfigValues) { + if (fieldConfigValueItem instanceof String) { + srcFieldSpecs.add(new SrcFieldSpec((String) fieldConfigValueItem)); + } else if (fieldConfigValueItem instanceof Map) { + srcFieldSpecs.add( + new SrcFieldSpec((Map) fieldConfigValueItem)); + } else { + String errorMessage = + String.format( + "Invalid field spec for output field '%s': %s", + fieldName, fieldConfig); + throw new IllegalArgumentException(errorMessage); + } + } + } + } catch (Exception e) { + String errorMessage = + String.format( + "Invalid field spec for output field '%s': %s", fieldName, fieldConfig); + throw new IllegalArgumentException(errorMessage, e); + } + } + + public boolean isMultimodalField() { + return srcFieldSpecs.size() > 1 + || srcFieldSpecs.stream() + .anyMatch(f -> !ModalityType.TEXT.equals(f.getModalityType())); + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + VectorFieldSpec that = (VectorFieldSpec) object; + return Objects.equals(fieldName, that.fieldName); + } + + @Override + public int hashCode() { + return Objects.hash(fieldName); + } +} diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java index 01c3e504032..c4d748db5ab 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java @@ -17,13 +17,14 @@ package org.apache.seatunnel.transform.nlpmodel.embedding.multimodal; -import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec; +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcField; +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcFieldSpec; import lombok.Getter; import lombok.extern.slf4j.Slf4j; import java.io.Serializable; -import java.util.Base64; +import java.util.List; @Slf4j @Getter @@ -31,26 +32,28 @@ public class MultimodalFieldValue implements Serializable { private static final long serialVersionUID = 1L; - private final FieldSpec fieldSpec; - private final Object value; + private final List srcFields; - public MultimodalFieldValue(FieldSpec fieldSpec, Object value) { - this.value = value; - fieldSpec.setModalityType(determineModalityType(fieldSpec, value)); - this.fieldSpec = fieldSpec; + public MultimodalFieldValue(List srcFields) { + this.srcFields = srcFields; + for (SrcField srcField : srcFields) { + SrcFieldSpec fieldSpec = srcField.getFieldSpec(); + ModalityType modalityType = determineModalityType(fieldSpec, srcField.getFieldValue()); + fieldSpec.setModalityType(modalityType); + } } /** * Determine the actual modality type based on field spec and value If not binary format, * analyze the value suffix to determine modality type */ - private ModalityType determineModalityType(FieldSpec fieldSpec, Object value) { + private ModalityType determineModalityType(SrcFieldSpec fieldSpec, Object fieldValue) { if (fieldSpec.isBinary()) { return fieldSpec.getModalityType(); } - if (value != null) { - String valueStr = value.toString(); + if (fieldValue != null) { + String valueStr = fieldValue.toString(); ModalityType detectedType = ModalityType.fromFileSuffix(valueStr); if (detectedType != null) { log.debug( @@ -60,11 +63,4 @@ private ModalityType determineModalityType(FieldSpec fieldSpec, Object value) { } return fieldSpec.getModalityType(); } - - public String toBase64() { - if (value == null) { - throw new IllegalArgumentException("Binary data cannot be null or empty"); - } - return Base64.getEncoder().encodeToString(value.toString().getBytes()); - } } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java index 46250ac829c..98493c5a007 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java @@ -23,7 +23,8 @@ import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode; import org.apache.seatunnel.shade.com.google.common.annotations.VisibleForTesting; -import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec; +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcField; +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcFieldSpec; import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType; import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalFieldValue; import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalModel; @@ -40,6 +41,7 @@ import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; public class DoubaoModel extends MultimodalModel { @@ -95,7 +97,7 @@ protected List> textVector(Object[] fields) throws IOException { public List> multimodalVector(Object[] fields) throws IOException { if (singleVectorizedInputNumber > 1) { throw new IllegalArgumentException( - "Doubao does not support batch multimodal vectorization in a single request. "); + "Doubao does not support batch multimodal vectorization in a single request."); } List> vectors = new ArrayList<>(); for (Object field : fields) { @@ -106,12 +108,15 @@ public List> multimodalVector(Object[] fields) throws IOException { @Override public Integer dimension() throws IOException { - return isMultimodalFields - ? multimodalVectorGeneration( - new MultimodalFieldValue( - new FieldSpec(DIMENSION_EXAMPLE), DIMENSION_EXAMPLE)) - .size() - : textVectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size(); + if (isMultimodalFields) { + SrcField srcField = + new SrcField(new SrcFieldSpec(DIMENSION_EXAMPLE), DIMENSION_EXAMPLE); + return multimodalVectorGeneration( + new MultimodalFieldValue(Collections.singletonList(srcField))) + .size(); + } else { + return textVectorGeneration(new Object[] {DIMENSION_EXAMPLE}).get(0).size(); + } } private List> textVectorGeneration(Object[] fields) throws IOException { @@ -214,35 +219,39 @@ public ObjectNode multimodalBody(MultimodalFieldValue field) { ObjectNode requestNode = OBJECT_MAPPER.createObjectNode(); requestNode.put("model", model); requestNode.put("encoding_format", "float"); - ArrayNode inputDatas = OBJECT_MAPPER.createArrayNode(); - inputDatas.add(inputRawData(field)); - requestNode.set("input", inputDatas); + ArrayNode inputNode = OBJECT_MAPPER.createArrayNode(); + inputNode.addAll(inputRawData(field)); + requestNode.set("input", inputNode); return requestNode; } - protected ObjectNode inputRawData(MultimodalFieldValue field) { - ObjectNode rawDataNode = OBJECT_MAPPER.createObjectNode(); - FieldSpec fieldSpec = field.getFieldSpec(); - String fieldValue = field.getValue().toString().trim(); - ModalityType fieldSpecModalityType = fieldSpec.getModalityType(); - String modalityParamName = getModalityParamName(fieldSpecModalityType); - rawDataNode.put("type", modalityParamName); - if (ModalityType.TEXT == fieldSpecModalityType) { - rawDataNode.put(modalityParamName, fieldValue); - return rawDataNode; - } - - if (fieldSpec.isBinary()) { - fieldValue = - String.format( - BASE64_PARAM_TEMPLATE, - fieldSpecModalityType.getGroup().name().toLowerCase(), - fieldSpecModalityType.getName(), - field.toBase64()); + protected List inputRawData(MultimodalFieldValue field) { + List rawDataNodes = new ArrayList<>(); + List srcFields = field.getSrcFields(); + for (SrcField srcField : srcFields) { + ObjectNode rawDataNode = OBJECT_MAPPER.createObjectNode(); + String fieldValue = srcField.getFieldValue().toString().trim(); + ModalityType fieldSpecModalityType = srcField.getFieldSpec().getModalityType(); + String modalityParamName = getModalityParamName(fieldSpecModalityType); + rawDataNode.put("type", modalityParamName); + if (ModalityType.TEXT == fieldSpecModalityType) { + rawDataNode.put(modalityParamName, fieldValue); + rawDataNodes.add(rawDataNode); + continue; + } + if (srcField.getFieldSpec().isBinary()) { + fieldValue = + String.format( + BASE64_PARAM_TEMPLATE, + fieldSpecModalityType.getGroup().name().toLowerCase(), + fieldSpecModalityType.getName(), + srcField.toBase64()); + } + rawDataNode.set( + modalityParamName, OBJECT_MAPPER.createObjectNode().put("url", fieldValue)); + rawDataNodes.add(rawDataNode); } - rawDataNode.set(modalityParamName, OBJECT_MAPPER.createObjectNode().put("url", fieldValue)); - - return rawDataNode; + return rawDataNodes; } private String getModalityParamName(ModalityType inputType) { diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java index b5e4689e632..fd32907012f 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java @@ -17,40 +17,58 @@ package org.apache.seatunnel.transform.embedding; -import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.ObjectMapper; import org.apache.seatunnel.shade.com.fasterxml.jackson.databind.node.ObjectNode; -import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec; +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcField; +import org.apache.seatunnel.transform.nlpmodel.embedding.VectorFieldSpec; import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType; import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.MultimodalFieldValue; import org.apache.seatunnel.transform.nlpmodel.embedding.remote.doubao.DoubaoModel; +import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.ThreadLocalRandom; public class DoubaoMultimodalModelTest { - private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + private DoubaoModel model; - @Test - void testMultimodalBodyWithText() throws IOException { - DoubaoModel model = + @BeforeEach + void setUp() { + this.model = new DoubaoModel( "test-api-key", "doubao-embedding-vision", "https://ark.cn-beijing.volces.com/api/v3/embeddings", 1); + } + + @AfterEach + void tearDown() throws IOException { + if (model != null) { + model.close(); + } + } + @Test + void testMultimodalBodyWithText() { Map.Entry textFieldEntry = - new java.util.AbstractMap.SimpleEntry<>("text_vector", "Hello world"); - FieldSpec fieldSpec = new FieldSpec(textFieldEntry); + new java.util.AbstractMap.SimpleEntry<>("text_vector", "text_field"); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(textFieldEntry); MultimodalFieldValue multimodalFieldValue = - new MultimodalFieldValue(fieldSpec, "Hello world"); + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), "Hello world"))); ObjectNode result = model.multimodalBody(multimodalFieldValue); @@ -63,8 +81,6 @@ void testMultimodalBodyWithText() throws IOException { Assertions.assertEquals("Hello world", inputNode.get("text").asText()); Assertions.assertFalse(inputNode.has("image_url")); Assertions.assertFalse(inputNode.has("video_url")); - - model.close(); } /** @@ -73,26 +89,21 @@ void testMultimodalBodyWithText() throws IOException { * "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" } }] } */ @Test - void testMultimodalBodyWithImage() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - + void testMultimodalBodyWithImage() { Map imageFieldConfig = new HashMap<>(); imageFieldConfig.put("field", "image_field"); imageFieldConfig.put("modality", "jpeg"); imageFieldConfig.put("format", "url"); - Map.Entry imageFieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector", imageFieldConfig); - FieldSpec fieldSpec = new FieldSpec(imageFieldEntry); + + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(imageFieldEntry); MultimodalFieldValue multimodalFieldValue = new MultimodalFieldValue( - fieldSpec, - "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg"); + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), + "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg"))); ObjectNode result = model.multimodalBody(multimodalFieldValue); @@ -110,8 +121,6 @@ void testMultimodalBodyWithImage() throws IOException { inputNode.get("image_url").get("url").asText()); Assertions.assertFalse(inputNode.has("text")); Assertions.assertFalse(inputNode.has("video_url")); - - model.close(); } /** @@ -119,24 +128,21 @@ void testMultimodalBodyWithImage() throws IOException { * "video_url", "video_url" : { "url" : "https://example.com/video.mp4" } } ] } */ @Test - void testMultimodalBodyWithVideo() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - + void testMultimodalBodyWithVideo() { Map videoFieldConfig = new HashMap<>(); videoFieldConfig.put("field", "video_field"); videoFieldConfig.put("modality", "mP4"); videoFieldConfig.put("format", "url"); - Map.Entry videoFieldEntry = new java.util.AbstractMap.SimpleEntry<>("video_vector", videoFieldConfig); - FieldSpec fieldSpec = new FieldSpec(videoFieldEntry); + + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(videoFieldEntry); MultimodalFieldValue multimodalFieldValue = - new MultimodalFieldValue(fieldSpec, "https://example.com/video.mp4"); + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), + "https://example.com/video.mp4"))); ObjectNode result = model.multimodalBody(multimodalFieldValue); @@ -151,8 +157,6 @@ void testMultimodalBodyWithVideo() throws IOException { "https://example.com/video.mp4", inputNode.get("video_url").get("url").asText()); Assertions.assertFalse(inputNode.has("text")); Assertions.assertFalse(inputNode.has("image_url")); - - model.close(); } /** @@ -160,39 +164,131 @@ void testMultimodalBodyWithVideo() throws IOException { * f"data:image/;base64,{base64_image}" } } */ @Test - void testMultimodalBodyWithBinaryImage() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision-250615", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - + void testMultimodalBodyWithBinaryImage() { Map binaryImageFieldConfig = new HashMap<>(); binaryImageFieldConfig.put("field", "binary_image_field"); binaryImageFieldConfig.put("modality", "png"); binaryImageFieldConfig.put("format", "binary"); - Map.Entry binaryImageFieldEntry = new java.util.AbstractMap.SimpleEntry<>( "binary_image_vector", binaryImageFieldConfig); - FieldSpec fieldSpec = new FieldSpec(binaryImageFieldEntry); - byte[] mockImageData = "mock-image-data".getBytes(); + + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(binaryImageFieldEntry); MultimodalFieldValue multimodalFieldValue = - new MultimodalFieldValue(fieldSpec, mockImageData); + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), mockImageData))); ObjectNode result = model.multimodalBody(multimodalFieldValue); - - Assertions.assertEquals("doubao-embedding-vision-250615", result.get("model").asText()); + Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText()); Assertions.assertEquals("float", result.get("encoding_format").asText()); Assertions.assertEquals(1, result.get("input").size()); ObjectNode inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("image_url", inputNode.get("type").asText()); Assertions.assertTrue(inputNode.has("image_url")); + } - model.close(); + /** + * { "model": "doubao-embedding-vision", "encoding_format": "float", "input": [ { "type": + * "text", "text": "Hello world 1" }, { "type": "text", "text": "Hello world 2" } ] } + */ + @Test + void testMultimodalBodyWithSameModalityList() { + Map.Entry vectorFieldEntry = + new java.util.AbstractMap.SimpleEntry<>( + "same_multimodal_vector", Arrays.asList("text_field_1", "text_field_2")); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(vectorFieldEntry); + MultimodalFieldValue multimodalFieldValue = + new MultimodalFieldValue( + Arrays.asList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), "Hello world 1"), + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(1), + "Hello world 2"))); + ObjectNode result = model.multimodalBody(multimodalFieldValue); + Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText()); + Assertions.assertEquals("float", result.get("encoding_format").asText()); + Assertions.assertEquals(2, result.get("input").size()); + ObjectNode inputNode = (ObjectNode) result.get("input").get(0); + Assertions.assertEquals("text", inputNode.get("type").asText()); + Assertions.assertEquals("Hello world 1", inputNode.get("text").asText()); + Assertions.assertFalse(inputNode.has("image_url")); + Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(1); + Assertions.assertEquals("text", inputNode.get("type").asText()); + Assertions.assertEquals("Hello world 2", inputNode.get("text").asText()); + Assertions.assertFalse(inputNode.has("image_url")); + Assertions.assertFalse(inputNode.has("video_url")); + } + + /** + * { "model": "doubao-embedding-vision", "encoding_format": "float", "input": [ { "type": + * "text", "text": "Hello world" }, { "type": "image_url", "image_url": { "url": + * "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" } }, { "type": + * "video_url", "video_url": { "url": "https://example.com/video.mp4" } } ] } + */ + @Test + void testMultimodalBodyWithDifferentModalityList() { + Object textFieldConfig = "text_field"; + if (ThreadLocalRandom.current().nextBoolean()) { + Map textFieldConfigMap = new HashMap<>(); + textFieldConfigMap.put("field", "text_field"); + textFieldConfigMap.put("modality", "text"); + textFieldConfigMap.put("format", "text"); + textFieldConfig = textFieldConfigMap; + } + Map imageFieldConfig = new HashMap<>(); + imageFieldConfig.put("field", "image_field"); + imageFieldConfig.put("modality", "jpeg"); + imageFieldConfig.put("format", "url"); + Map videoFieldConfig = new HashMap<>(); + videoFieldConfig.put("field", "video_field"); + videoFieldConfig.put("modality", "mp4"); + videoFieldConfig.put("format", "url"); + Map.Entry vectorFieldEntry = + new java.util.AbstractMap.SimpleEntry<>( + "different_multimodal_vector", + Arrays.asList(textFieldConfig, imageFieldConfig, videoFieldConfig)); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(vectorFieldEntry); + MultimodalFieldValue multimodalFieldValue = + new MultimodalFieldValue( + Arrays.asList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), "Hello world"), + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(1), + "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg"), + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(2), + "https://example.com/video.mp4"))); + ObjectNode result = model.multimodalBody(multimodalFieldValue); + Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText()); + Assertions.assertEquals("float", result.get("encoding_format").asText()); + Assertions.assertEquals(3, result.get("input").size()); + ObjectNode inputNode = (ObjectNode) result.get("input").get(0); + Assertions.assertEquals("text", inputNode.get("type").asText()); + Assertions.assertEquals("Hello world", inputNode.get("text").asText()); + Assertions.assertFalse(inputNode.has("image_url")); + Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(1); + Assertions.assertEquals("image_url", inputNode.get("type").asText()); + Assertions.assertTrue(inputNode.has("image_url")); + Assertions.assertEquals( + "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg", + inputNode.get("image_url").get("url").asText()); + Assertions.assertFalse(inputNode.has("text")); + Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(2); + Assertions.assertEquals("video_url", inputNode.get("type").asText()); + Assertions.assertTrue(inputNode.has("video_url")); + Assertions.assertEquals( + "https://example.com/video.mp4", inputNode.get("video_url").get("url").asText()); + Assertions.assertFalse(inputNode.has("text")); + Assertions.assertFalse(inputNode.has("image_url")); } @Test @@ -241,27 +337,25 @@ void testParseMultimodalVectorResponseSuccess() throws IOException { } @Test - void testUrlAutoDetectModality() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - + void testUrlAutoDetectModality() { Map fieldConfig = new HashMap<>(); fieldConfig.put("field", "image_field"); fieldConfig.put("format", "url"); fieldConfig.put("modality", "png"); - Map.Entry fieldEntry = + Map.Entry imageFieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector", fieldConfig); - FieldSpec fieldSpec = new FieldSpec(fieldEntry); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(imageFieldEntry); MultimodalFieldValue multimodalFieldValue = - new MultimodalFieldValue(fieldSpec, "https://example.com/photo.jpg"); + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), + "https://example.com/photo.jpg"))); Assertions.assertEquals( - ModalityType.JPEG, multimodalFieldValue.getFieldSpec().getModalityType()); + ModalityType.JPEG, + multimodalFieldValue.getSrcFields().get(0).getFieldSpec().getModalityType()); ObjectNode result = model.multimodalBody(multimodalFieldValue); ObjectNode inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("image_url", inputNode.get("type").asText()); @@ -269,46 +363,45 @@ void testUrlAutoDetectModality() throws IOException { Map fieldConfig2 = new HashMap<>(); fieldConfig2.put("field", "image_field"); fieldConfig2.put("format", "url"); - fieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector", fieldConfig2); - fieldSpec = new FieldSpec(fieldEntry); - - multimodalFieldValue = new MultimodalFieldValue(fieldSpec, "https://example.com/photo.jpg"); + imageFieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector", fieldConfig2); + vectorFieldSpec = new VectorFieldSpec(imageFieldEntry); + multimodalFieldValue = + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), + "https://example.com/photo.jpg"))); Assertions.assertEquals( - ModalityType.JPEG, multimodalFieldValue.getFieldSpec().getModalityType()); + ModalityType.JPEG, + multimodalFieldValue.getSrcFields().get(0).getFieldSpec().getModalityType()); result = model.multimodalBody(multimodalFieldValue); inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("image_url", inputNode.get("type").asText()); - - model.close(); } @Test - void testBinaryAutoDetectModality() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - + void testBinaryAutoDetectModality() { Map fieldConfig = new HashMap<>(); fieldConfig.put("field", "image_field"); fieldConfig.put("format", "binary"); fieldConfig.put("modality", "png"); - Map.Entry fieldEntry = + Map.Entry imageFieldEntry = new java.util.AbstractMap.SimpleEntry<>("image_vector", fieldConfig); - FieldSpec fieldSpec = new FieldSpec(fieldEntry); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(imageFieldEntry); MultimodalFieldValue multimodalFieldValue = - new MultimodalFieldValue(fieldSpec, "https://example.com/photo.jpg"); + new MultimodalFieldValue( + Collections.singletonList( + new SrcField( + vectorFieldSpec.getSrcFieldSpecs().get(0), + "https://example.com/photo.jpg"))); Assertions.assertEquals( - ModalityType.PNG, multimodalFieldValue.getFieldSpec().getModalityType()); + ModalityType.PNG, + multimodalFieldValue.getSrcFields().get(0).getFieldSpec().getModalityType()); ObjectNode result = model.multimodalBody(multimodalFieldValue); ObjectNode inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("image_url", inputNode.get("type").asText()); - - model.close(); } } diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java deleted file mode 100644 index c97372f8fe2..00000000000 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/FieldSpecTest.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.seatunnel.transform.embedding; - -import org.apache.seatunnel.transform.nlpmodel.embedding.FieldSpec; -import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType; -import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat; - -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -import java.util.AbstractMap; -import java.util.HashMap; -import java.util.Map; - -public class FieldSpecTest { - - @Test - void testMapEntryConstructorWithStringValue() { - Map.Entry entry = - new AbstractMap.SimpleEntry<>("book_intro_vector", "book_intro"); - FieldSpec fieldSpec = new FieldSpec(entry); - Assertions.assertEquals("book_intro", fieldSpec.getFieldName()); - Assertions.assertEquals(ModalityType.TEXT, fieldSpec.getModalityType()); - Assertions.assertEquals(PayloadFormat.TEXT, fieldSpec.getPayloadFormat()); - Assertions.assertFalse(fieldSpec.isMultimodalField()); - Assertions.assertFalse(fieldSpec.isBinary()); - } - - @Test - void testMapEntryConstructorWithStringValueTrimming() { - Map.Entry entry = - new AbstractMap.SimpleEntry<>("book_intro_vector", " book_intro "); - FieldSpec fieldSpec = new FieldSpec(entry); - Assertions.assertEquals("book_intro", fieldSpec.getFieldName()); - Assertions.assertEquals(ModalityType.TEXT, fieldSpec.getModalityType()); - Assertions.assertEquals(PayloadFormat.TEXT, fieldSpec.getPayloadFormat()); - } - - @Test - void testMapEntryConstructorWithNullKey() { - Map.Entry entry = new AbstractMap.SimpleEntry<>(null, "book_intro"); - IllegalArgumentException exception = - Assertions.assertThrows(IllegalArgumentException.class, () -> new FieldSpec(entry)); - Assertions.assertTrue(exception.getMessage().contains("Field spec cannot be null")); - } - - @Test - void testMapEntryConstructorWithEmpty() { - Map.Entry entry = new AbstractMap.SimpleEntry<>("book_intro_vector", null); - IllegalArgumentException exception = - Assertions.assertThrows(IllegalArgumentException.class, () -> new FieldSpec(entry)); - Assertions.assertTrue( - exception.getMessage().contains("Invalid field spec for output field")); - - Map.Entry entry2 = new AbstractMap.SimpleEntry<>("book_intro_vector", ""); - exception = - Assertions.assertThrows( - IllegalArgumentException.class, () -> new FieldSpec(entry2)); - Assertions.assertTrue( - exception.getMessage().contains("Invalid field spec for output field")); - } - - @Test - void testMapEntryConstructorWithMapValue() { - - Map fieldConfig = new HashMap<>(); - fieldConfig.put("field", "book_image"); - fieldConfig.put("modality", "jpeg"); - fieldConfig.put("format", "binary"); - - Map.Entry entry = new AbstractMap.SimpleEntry<>("book_field", fieldConfig); - - FieldSpec fieldSpec = new FieldSpec(entry); - - Assertions.assertEquals("book_image", fieldSpec.getFieldName()); - Assertions.assertEquals(ModalityType.JPEG, fieldSpec.getModalityType()); - Assertions.assertEquals(PayloadFormat.BINARY, fieldSpec.getPayloadFormat()); - Assertions.assertTrue(fieldSpec.isMultimodalField()); - Assertions.assertTrue(fieldSpec.isBinary()); - } - - @Test - void testMapEntryConstructorWithMapValueNoModality() { - Map fieldConfig = new HashMap<>(); - fieldConfig.put("field", "book_intro"); - fieldConfig.put("modality", "text"); - fieldConfig.put("format", "text"); - - Map.Entry entry = new AbstractMap.SimpleEntry<>("book_field", fieldConfig); - - FieldSpec fieldSpec = new FieldSpec(entry); - - Assertions.assertEquals("book_intro", fieldSpec.getFieldName()); - Assertions.assertEquals(ModalityType.TEXT, fieldSpec.getModalityType()); - Assertions.assertEquals(PayloadFormat.TEXT, fieldSpec.getPayloadFormat()); - Assertions.assertFalse(fieldSpec.isMultimodalField()); - } -} diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java new file mode 100644 index 00000000000..27f46d993cb --- /dev/null +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.seatunnel.transform.embedding; + +import org.apache.seatunnel.transform.nlpmodel.embedding.SrcFieldSpec; +import org.apache.seatunnel.transform.nlpmodel.embedding.VectorFieldSpec; +import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.ModalityType; +import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.AbstractMap; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class VectorFieldSpecTest { + + @Test + void testMapEntryConstructorWithStringValue() { + Map.Entry entry = + new AbstractMap.SimpleEntry<>("book_intro_vector", "book_intro"); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + Assertions.assertEquals("book_intro_vector", vectorFieldSpec.getFieldName()); + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + Assertions.assertEquals("book_intro", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(vectorFieldSpec.isMultimodalField()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + } + + @Test + void testMapEntryConstructorWithStringValueTrimming() { + Map.Entry entry = + new AbstractMap.SimpleEntry<>("book_intro_vector", " book_intro "); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + Assertions.assertEquals("book_intro", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + } + + @Test + void testMapEntryConstructorWithNullKey() { + Map.Entry entry = new AbstractMap.SimpleEntry<>(null, "book_intro"); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> new VectorFieldSpec(entry)); + Assertions.assertTrue( + exception.getMessage().contains("Field config name be null or empty")); + } + + @Test + void testMapEntryConstructorWithEmpty() { + Map.Entry entry = new AbstractMap.SimpleEntry<>("book_intro_vector", null); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> new VectorFieldSpec(entry)); + Assertions.assertTrue(exception.getMessage().contains("Field config value cannot be null")); + + Map.Entry entry2 = new AbstractMap.SimpleEntry<>("book_intro_vector", ""); + exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> new VectorFieldSpec(entry2)); + Assertions.assertTrue( + exception.getMessage().contains("Invalid field spec for output field")); + } + + @Test + void testMapEntryConstructorWithMapValue() { + Map fieldConfig = new HashMap<>(); + fieldConfig.put("field", "book_image"); + fieldConfig.put("modality", "jpeg"); + fieldConfig.put("format", "binary"); + + Map.Entry entry = new AbstractMap.SimpleEntry<>("book_field", fieldConfig); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + + Assertions.assertEquals("book_image", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.JPEG, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.BINARY, srcFieldSpec.getPayloadFormat()); + Assertions.assertTrue(vectorFieldSpec.isMultimodalField()); + Assertions.assertTrue(srcFieldSpec.isBinary()); + } + + @Test + void testMapEntryConstructorWithMapValueNoModality() { + Map fieldConfig = new HashMap<>(); + fieldConfig.put("field", "book_intro"); + fieldConfig.put("modality", "text"); + fieldConfig.put("format", "text"); + + Map.Entry entry = new AbstractMap.SimpleEntry<>("book_field", fieldConfig); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + + Assertions.assertEquals("book_intro", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(vectorFieldSpec.isMultimodalField()); + } + + @Test + void testMapEntryConstructorWithInvalidListValue() { + List textFieldConfig = Arrays.asList("text_field_1", "text_field_2"); + Map imageFieldConfig = new HashMap<>(); + imageFieldConfig.put("field", "image_field"); + imageFieldConfig.put("modality", "jpeg"); + imageFieldConfig.put("format", "url"); + + Map.Entry entry = + new AbstractMap.SimpleEntry<>( + "vector_field", Arrays.asList(textFieldConfig, imageFieldConfig)); + IllegalArgumentException exception = + Assertions.assertThrows( + IllegalArgumentException.class, () -> new VectorFieldSpec(entry)); + Assertions.assertTrue( + exception.getMessage().contains("Invalid field spec for output field")); + } + + @Test + void testMapEntryConstructorWithSameModalityListValue() { + Map.Entry entry = + new AbstractMap.SimpleEntry<>( + "vector_field", Arrays.asList("text_field_1", "text_field_2")); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + Assertions.assertEquals("vector_field", vectorFieldSpec.getFieldName()); + Assertions.assertTrue(vectorFieldSpec.isMultimodalField()); + + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + Assertions.assertEquals("text_field_1", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + + srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(1); + Assertions.assertEquals("text_field_2", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + } + + @Test + void testMapEntryConstructorWithDifferentModalityListValue() { + Map imageFieldConfig = new HashMap<>(); + imageFieldConfig.put("field", "image_field"); + imageFieldConfig.put("modality", "jpeg"); + imageFieldConfig.put("format", "url"); + + Map videoFieldConfig = new HashMap<>(); + videoFieldConfig.put("field", "video_field"); + videoFieldConfig.put("modality", "mp4"); + videoFieldConfig.put("format", "url"); + + Map.Entry entry = + new AbstractMap.SimpleEntry<>( + "vector_field", + Arrays.asList("text_field", imageFieldConfig, videoFieldConfig)); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(entry); + Assertions.assertEquals("vector_field", vectorFieldSpec.getFieldName()); + Assertions.assertTrue(vectorFieldSpec.isMultimodalField()); + + SrcFieldSpec srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(0); + Assertions.assertEquals("text_field", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.TEXT, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.TEXT, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + + srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(1); + Assertions.assertEquals("image_field", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.JPEG, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.URL, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + + srcFieldSpec = vectorFieldSpec.getSrcFieldSpecs().get(2); + Assertions.assertEquals("video_field", srcFieldSpec.getFieldName()); + Assertions.assertEquals(ModalityType.MP4, srcFieldSpec.getModalityType()); + Assertions.assertEquals(PayloadFormat.URL, srcFieldSpec.getPayloadFormat()); + Assertions.assertFalse(srcFieldSpec.isBinary()); + } +} From 982c070427d24c9826cfbf1282ee442c9ee5b1b9 Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Wed, 29 Oct 2025 00:34:42 +0800 Subject: [PATCH 2/6] fix: fix VectorFieldSpecTest test case --- .../seatunnel/transform/embedding/VectorFieldSpecTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java index 27f46d993cb..42677ead467 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/VectorFieldSpecTest.java @@ -65,7 +65,7 @@ void testMapEntryConstructorWithNullKey() { Assertions.assertThrows( IllegalArgumentException.class, () -> new VectorFieldSpec(entry)); Assertions.assertTrue( - exception.getMessage().contains("Field config name be null or empty")); + exception.getMessage().contains("Field config name cannot be null or empty")); } @Test From 6495a12cb575fdffbd70755a5b66ac0d45a6de7e Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Wed, 29 Oct 2025 00:57:52 +0800 Subject: [PATCH 3/6] chore: add spotless toggleOffOn --- pom.xml | 2 + .../embedding/DoubaoMultimodalModelTest.java | 89 ++++++++++++++++--- 2 files changed, 78 insertions(+), 13 deletions(-) diff --git a/pom.xml b/pom.xml index 11df2d57c7e..c9a3b4d80f8 100644 --- a/pom.xml +++ b/pom.xml @@ -916,6 +916,8 @@ + + org.apache.seatunnel.shade,org.apache.seatunnel,org.apache,org,,javax,java,\# diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java index fd32907012f..083c45f85e7 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java @@ -84,9 +84,20 @@ void testMultimodalBodyWithText() { } /** - * { "model" : "doubao-embedding-vision", "encoding_format" : "float", "input" : [ { "type" : - * "image_url", "image_url" : { "url" : - * "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" } }] } + * spotless:off + * { + * "model": "doubao-embedding-vision", + * "encoding_format": "float", + * "input": [ + * { + * "type": "image_url", + * "image_url": { + * "url": "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" + * } + * } + * ] + * } + * spotless:on */ @Test void testMultimodalBodyWithImage() { @@ -124,8 +135,20 @@ void testMultimodalBodyWithImage() { } /** - * { "model" : "doubao-embedding-vision", "encoding_format" : "float", "input" : [ { "type" : - * "video_url", "video_url" : { "url" : "https://example.com/video.mp4" } } ] } + * spotless:off + * { + * "model": "doubao-embedding-vision", + * "encoding_format": "float", + * "input": [ + * { + * "type": "video_url", + * "video_url": { + * "url": "https://example.com/video.mp4" + * } + * } + * ] + * } + * spotless:on */ @Test void testMultimodalBodyWithVideo() { @@ -160,8 +183,14 @@ void testMultimodalBodyWithVideo() { } /** - * { "type": "image_url", "image_url": { "url": - * f"data:image/;base64,{base64_image}" } } + * spotless:off + * { + * "type": "image_url", + * "image_url": { + * "url": f"data:image/;base64,{base64_image}" + * } + * } + * spotless:on */ @Test void testMultimodalBodyWithBinaryImage() { @@ -192,8 +221,22 @@ void testMultimodalBodyWithBinaryImage() { } /** - * { "model": "doubao-embedding-vision", "encoding_format": "float", "input": [ { "type": - * "text", "text": "Hello world 1" }, { "type": "text", "text": "Hello world 2" } ] } + * spotless:off + * { + * "model": "doubao-embedding-vision", + * "encoding_format": "float", + * "input": [ + * { + * "type": "text", + * "text": "Hello world 1" + * }, + * { + * "type": "text", + * "text": "Hello world 2" + * } + * ] + * } + * spotless:on */ @Test void testMultimodalBodyWithSameModalityList() { @@ -226,10 +269,30 @@ void testMultimodalBodyWithSameModalityList() { } /** - * { "model": "doubao-embedding-vision", "encoding_format": "float", "input": [ { "type": - * "text", "text": "Hello world" }, { "type": "image_url", "image_url": { "url": - * "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" } }, { "type": - * "video_url", "video_url": { "url": "https://example.com/video.mp4" } } ] } + * spotless:off + * { + * "model": "doubao-embedding-vision", + * "encoding_format": "float", + * "input": [ + * { + * "type": "text", + * "text": "Hello world" + * }, + * { + * "type": "image_url", + * "image_url": { + * "url": "https://ck-test.tos-cn-beijing.volces.com/vlm/pexels-photo-27163466.jpeg" + * } + * }, + * { + * "type": "video_url", + * "video_url": { + * "url": "https://example.com/video.mp4" + * } + * } + * ] + * } + * spotless:on */ @Test void testMultimodalBodyWithDifferentModalityList() { From d9fcc530e4ee34fb57b99b935c80cbe5de63b127 Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Thu, 30 Oct 2025 01:58:19 +0800 Subject: [PATCH 4/6] chore: code format --- .../seatunnel/transform/nlpmodel/embedding/SrcField.java | 4 +--- .../embedding/multimodal/MultimodalFieldValue.java | 1 - .../nlpmodel/embedding/remote/doubao/DoubaoModel.java | 1 + .../transform/embedding/DoubaoMultimodalModelTest.java | 8 ++++++++ 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java index c1fbcdd2b1d..470ae06e869 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/SrcField.java @@ -17,8 +17,6 @@ package org.apache.seatunnel.transform.nlpmodel.embedding; -import org.apache.seatunnel.transform.nlpmodel.embedding.multimodal.PayloadFormat; - import lombok.Data; import java.io.Serializable; @@ -39,7 +37,7 @@ public SrcField(SrcFieldSpec spec, Object value) { } public String toBase64() { - if (fieldSpec == null || !PayloadFormat.BINARY.equals(fieldSpec.getPayloadFormat())) { + if (fieldSpec == null || !fieldSpec.isBinary()) { throw new IllegalArgumentException("Payload format must be binary"); } if (fieldValue == null) { diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java index c4d748db5ab..630b310caa9 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/multimodal/MultimodalFieldValue.java @@ -48,7 +48,6 @@ public MultimodalFieldValue(List srcFields) { * analyze the value suffix to determine modality type */ private ModalityType determineModalityType(SrcFieldSpec fieldSpec, Object fieldValue) { - if (fieldSpec.isBinary()) { return fieldSpec.getModalityType(); } diff --git a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java index 98493c5a007..161dea4e1bb 100644 --- a/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java +++ b/seatunnel-transforms-v2/src/main/java/org/apache/seatunnel/transform/nlpmodel/embedding/remote/doubao/DoubaoModel.java @@ -239,6 +239,7 @@ protected List inputRawData(MultimodalFieldValue field) { rawDataNodes.add(rawDataNode); continue; } + if (srcField.getFieldSpec().isBinary()) { fieldValue = String.format( diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java index 083c45f85e7..2edfe21f5be 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java @@ -252,15 +252,18 @@ void testMultimodalBodyWithSameModalityList() { new SrcField( vectorFieldSpec.getSrcFieldSpecs().get(1), "Hello world 2"))); + ObjectNode result = model.multimodalBody(multimodalFieldValue); Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText()); Assertions.assertEquals("float", result.get("encoding_format").asText()); Assertions.assertEquals(2, result.get("input").size()); + ObjectNode inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("text", inputNode.get("type").asText()); Assertions.assertEquals("Hello world 1", inputNode.get("text").asText()); Assertions.assertFalse(inputNode.has("image_url")); Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(1); Assertions.assertEquals("text", inputNode.get("type").asText()); Assertions.assertEquals("Hello world 2", inputNode.get("text").asText()); @@ -316,6 +319,7 @@ void testMultimodalBodyWithDifferentModalityList() { new java.util.AbstractMap.SimpleEntry<>( "different_multimodal_vector", Arrays.asList(textFieldConfig, imageFieldConfig, videoFieldConfig)); + VectorFieldSpec vectorFieldSpec = new VectorFieldSpec(vectorFieldEntry); MultimodalFieldValue multimodalFieldValue = new MultimodalFieldValue( @@ -328,15 +332,18 @@ void testMultimodalBodyWithDifferentModalityList() { new SrcField( vectorFieldSpec.getSrcFieldSpecs().get(2), "https://example.com/video.mp4"))); + ObjectNode result = model.multimodalBody(multimodalFieldValue); Assertions.assertEquals("doubao-embedding-vision", result.get("model").asText()); Assertions.assertEquals("float", result.get("encoding_format").asText()); Assertions.assertEquals(3, result.get("input").size()); + ObjectNode inputNode = (ObjectNode) result.get("input").get(0); Assertions.assertEquals("text", inputNode.get("type").asText()); Assertions.assertEquals("Hello world", inputNode.get("text").asText()); Assertions.assertFalse(inputNode.has("image_url")); Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(1); Assertions.assertEquals("image_url", inputNode.get("type").asText()); Assertions.assertTrue(inputNode.has("image_url")); @@ -345,6 +352,7 @@ void testMultimodalBodyWithDifferentModalityList() { inputNode.get("image_url").get("url").asText()); Assertions.assertFalse(inputNode.has("text")); Assertions.assertFalse(inputNode.has("video_url")); + inputNode = (ObjectNode) result.get("input").get(2); Assertions.assertEquals("video_url", inputNode.get("type").asText()); Assertions.assertTrue(inputNode.has("video_url")); From f2ee8635b19b498cb7ba16098721b61b57100327 Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Fri, 31 Oct 2025 02:06:33 +0800 Subject: [PATCH 5/6] chore:remove deduplicated code --- .../transform/embedding/DoubaoMultimodalModelTest.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java index 2edfe21f5be..86ce1085543 100644 --- a/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java +++ b/seatunnel-transforms-v2/src/test/java/org/apache/seatunnel/transform/embedding/DoubaoMultimodalModelTest.java @@ -364,13 +364,6 @@ void testMultimodalBodyWithDifferentModalityList() { @Test void testParseMultimodalVectorResponseSuccess() throws IOException { - DoubaoModel model = - new DoubaoModel( - "test-api-key", - "doubao-embedding-vision", - "https://ark.cn-beijing.volces.com/api/v3/embeddings", - 1); - String successResponse = "{\n" + " \"created\": 1743575029,\n" @@ -403,8 +396,6 @@ void testParseMultimodalVectorResponseSuccess() throws IOException { Assertions.assertEquals(-0.318359375f, result.get(2), 0.0001f); Assertions.assertEquals(0.255859375f, result.get(3), 0.0001f); Assertions.assertEquals(1.5f, result.get(4), 0.0001f); - - model.close(); } @Test From bdbd0122ea88cc6d931ececaf5607ef9ffd52ec5 Mon Sep 17 00:00:00 2001 From: "chenlantian.michael" Date: Tue, 4 Nov 2025 22:43:33 +0800 Subject: [PATCH 6/6] chore: modify comment --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index c9a3b4d80f8..cc909b8d84b 100644 --- a/pom.xml +++ b/pom.xml @@ -916,7 +916,7 @@ - + org.apache.seatunnel.shade,org.apache.seatunnel,org.apache,org,,javax,java,\#