diff --git a/java/lance-jni/src/fragment.rs b/java/lance-jni/src/fragment.rs index c3caec67d82..fb282a9f939 100644 --- a/java/lance-jni/src/fragment.rs +++ b/java/lance-jni/src/fragment.rs @@ -29,12 +29,6 @@ use crate::{ RT, }; -#[derive(Debug, Clone)] -pub(crate) struct FragmentMergeResult { - fragment: Fragment, - schema: Schema, -} - #[derive(Debug, Clone)] pub(crate) struct FragmentUpdateResult { updated_fragment: Fragment, @@ -332,6 +326,7 @@ pub extern "system" fn Java_org_lance_Fragment_nativeMergeColumns<'a>( arrow_array_stream_addr: jlong, // memoryAddress of ArrowStream left_on: JString, // left column name to join on right_on: JString, // right column name to join on + arrow_schema_addr: jlong, // memoryAddress of arrow Schema ) -> JObject<'a> { ok_or_throw_with_return!( env, @@ -341,7 +336,8 @@ pub extern "system" fn Java_org_lance_Fragment_nativeMergeColumns<'a>( fragment_id, arrow_array_stream_addr, left_on, - right_on + right_on, + arrow_schema_addr ), JObject::null() ) @@ -355,6 +351,7 @@ fn inner_merge_column<'local>( arrow_array_stream_addr: jlong, left_on: JString, right_on: JString, + arrow_schema_addr: jlong, ) -> Result> { let (fragment_opt, max_field_id) = { let dataset = @@ -378,13 +375,13 @@ fn inner_merge_column<'local>( let left_on_str: String = left_on.extract(env)?; let right_on_str: String = right_on.extract(env)?; - let (new_frag, new_schema) = + let (new_frag, new_schema): (Fragment, Schema) = RT.block_on(fragment.merge_columns(reader, &left_on_str, &right_on_str, max_field_id))?; - let result = FragmentMergeResult { - fragment: new_frag, - schema: new_schema, - }; - result.into_java(env) + + let ffi_schema = FFI_ArrowSchema::try_from(&arrow_schema::Schema::from(&new_schema))?; + unsafe { std::ptr::write_unaligned(arrow_schema_addr as *mut FFI_ArrowSchema, ffi_schema) } + + new_frag.into_java(env) } #[no_mangle] @@ -456,27 +453,9 @@ const FRAGMENT_METADATA_CLASS: &str = "org/lance/FragmentMetadata"; const FRAGMENT_METADATA_CONSTRUCTOR_SIG: &str ="(ILjava/util/List;Ljava/lang/Long;Lorg/lance/fragment/DeletionFile;Lorg/lance/fragment/RowIdMeta;)V"; const ROW_ID_META_CLASS: &str = "org/lance/fragment/RowIdMeta"; const ROW_ID_META_CONSTRUCTOR_SIG: &str = "(Ljava/lang/String;)V"; -const FRAGMENT_MERGE_RESULT_CLASS: &str = "org/lance/fragment/FragmentMergeResult"; -const FRAGMENT_MERGE_RESULT_CONSTRUCTOR_SIG: &str = - "(Lorg/lance/FragmentMetadata;Lorg/lance/schema/LanceSchema;)V"; const FRAGMENT_UPDATE_RESULT_CLASS: &str = "org/lance/fragment/FragmentUpdateResult"; const FRAGMENT_UPDATE_RESULT_CONSTRUCTOR_SIG: &str = "(Lorg/lance/FragmentMetadata;[J)V"; -impl IntoJava for &FragmentMergeResult { - fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { - let java_fragment_meta_data = self.fragment.into_java(env)?; - let java_lance_schema = self.schema.clone().into_java(env)?; - Ok(env.new_object( - FRAGMENT_MERGE_RESULT_CLASS, - FRAGMENT_MERGE_RESULT_CONSTRUCTOR_SIG, - &[ - JValueGen::Object(&java_fragment_meta_data), - JValueGen::Object(&java_lance_schema), - ], - )?) - } -} - impl IntoJava for &FragmentUpdateResult { fn into_java<'a>(self, env: &mut JNIEnv<'a>) -> Result> { let java_updated_fragment = self.updated_fragment.into_java(env)?; diff --git a/java/src/main/java/org/lance/Fragment.java b/java/src/main/java/org/lance/Fragment.java index 812fb49548c..c893c4f3e7b 100644 --- a/java/src/main/java/org/lance/Fragment.java +++ b/java/src/main/java/org/lance/Fragment.java @@ -26,6 +26,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; import java.util.Arrays; import java.util.List; @@ -150,16 +151,28 @@ public int countRows() { * @return the fragment metadata and new schema. */ public FragmentMergeResult mergeColumns(ArrowArrayStream stream, String leftOn, String rightOn) { - return nativeMergeColumns( - dataset, fragmentMetadata.getId(), stream.memoryAddress(), leftOn, rightOn); + try (ArrowSchema ffiArrowSchema = ArrowSchema.allocateNew(dataset.allocator())) { + FragmentMetadata metadata = + nativeMergeColumns( + dataset, + fragmentMetadata.getId(), + stream.memoryAddress(), + leftOn, + rightOn, + ffiArrowSchema.memoryAddress()); + + Schema schema = Data.importSchema(dataset.allocator(), ffiArrowSchema, null); + return new FragmentMergeResult(metadata, schema); + } } - private native FragmentMergeResult nativeMergeColumns( + private native FragmentMetadata nativeMergeColumns( Dataset dataset, long fragmentId, long arrowStreamMemoryAddress, String leftOn, - String rightOn); + String rightOn, + long schemaMemoryAddress); /** * Update existed columns into this Fragment, will return the new fragment with the same diff --git a/java/src/main/java/org/lance/fragment/FragmentMergeResult.java b/java/src/main/java/org/lance/fragment/FragmentMergeResult.java index 7d22f3e1e13..0c15230ffd2 100644 --- a/java/src/main/java/org/lance/fragment/FragmentMergeResult.java +++ b/java/src/main/java/org/lance/fragment/FragmentMergeResult.java @@ -14,10 +14,10 @@ package org.lance.fragment; import org.lance.FragmentMetadata; -import org.lance.schema.LanceSchema; import com.google.common.base.MoreObjects; import org.apache.arrow.c.ArrowArrayStream; +import org.apache.arrow.vector.types.pojo.Schema; /** * Result of {@link org.lance.Fragment#mergeColumns(ArrowArrayStream, String, String) @@ -25,14 +25,14 @@ */ public class FragmentMergeResult { private final FragmentMetadata fragmentMetadata; - private final LanceSchema schema; + private final Schema schema; - public FragmentMergeResult(FragmentMetadata fragmentMetadata, LanceSchema schema) { + public FragmentMergeResult(FragmentMetadata fragmentMetadata, Schema schema) { this.fragmentMetadata = fragmentMetadata; this.schema = schema; } - public LanceSchema getSchema() { + public Schema getSchema() { return schema; } diff --git a/java/src/test/java/org/lance/FragmentTest.java b/java/src/test/java/org/lance/FragmentTest.java index c1e646f492b..1203720c4f1 100644 --- a/java/src/test/java/org/lance/FragmentTest.java +++ b/java/src/test/java/org/lance/FragmentTest.java @@ -299,7 +299,7 @@ void testMergeColumns(@TempDir Path tempDir) throws Exception { .operation( Merge.builder() .fragments(Collections.singletonList(mergeResult.getFragmentMetadata())) - .schema(mergeResult.getSchema().asArrowSchema()) + .schema(mergeResult.getSchema()) .build()) .readVersion(dataset.version()) .build();