@@ -25,7 +25,7 @@ import org.apache.arrow.memory.{BaseAllocator, RootAllocator}
 import org.apache.arrow.vector._
 import org.apache.arrow.vector.BaseValueVector.BaseMutator
 import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch}
-import org.apache.arrow.vector.types.FloatingPointPrecision
+import org.apache.arrow.vector.types.{FloatingPointPrecision, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, Schema}
 
 import org.apache.spark.sql.catalyst.InternalRow
@@ -46,6 +46,9 @@ object Arrow {
       case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
       case ByteType => new ArrowType.Int(8, true)
       case StringType => ArrowType.Utf8.INSTANCE
+      case BinaryType => ArrowType.Binary.INSTANCE
+      case DateType => ArrowType.Date.INSTANCE
+      case TimestampType => new ArrowType.Timestamp(TimeUnit.MILLISECOND)
       case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
     }
   }
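The three new ArrowType cases mirror the column writers added further down. As a rough sketch of how this mapping feeds an Arrow schema (the helper name `toArrowType` and the exact pojo `Field` constructor shape are assumptions, not shown in this hunk):

    // Hypothetical driver code: one Arrow Field per Spark StructField
    val arrowFields = schema.fields.map { f =>
      new Field(f.name, f.nullable, toArrowType(f.dataType), List.empty[Field].asJava)
    }
    val arrowSchema = new Schema(arrowFields.toList.asJava)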
@@ -57,12 +60,17 @@ object Arrow {
       rows: Array[InternalRow],
       schema: StructType,
       allocator: RootAllocator): ArrowRecordBatch = {
-    val (fieldNodes, buffers) = schema.fields.zipWithIndex.map { case (field, ordinal) =>
+    val fieldAndBuf = schema.fields.zipWithIndex.map { case (field, ordinal) =>
       internalRowToArrowBuf(rows, ordinal, field, allocator)
     }.unzip
+    val fieldNodes = fieldAndBuf._1.flatten
+    val buffers = fieldAndBuf._2.flatten
 
-    new ArrowRecordBatch(rows.length,
-      fieldNodes.flatten.toList.asJava, buffers.flatten.toList.asJava)
+    val recordBatch = new ArrowRecordBatch(rows.length,
+      fieldNodes.toList.asJava, buffers.toList.asJava)
+
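+    // The ArrowRecordBatch takes its own reference on each buffer passed in,
+    // so release the local references returned by finish(); each ArrowBuf
+    // then lives exactly as long as the batch, instead of leaking.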
+    buffers.foreach(_.release())
+    recordBatch
   }
 
   /**
@@ -107,6 +115,11 @@ private[sql] trait ColumnWriter {
   def init(initialSize: Int): Unit
   def writeNull(): Unit
   def write(row: InternalRow, ordinal: Int): Unit
+
+  /**
+   * Clear the column writer and return the ArrowFieldNodes and ArrowBufs.
+   * This should be called only once, after all the data has been written.
+   */
   def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf])
 }
 
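The new doc comment pins down the writer lifecycle. A minimal sketch of the intended call sequence (hypothetical driver code, not part of this patch):

    val writer = ColumnWriter(allocator, field.dataType)
    writer.init(rows.length)
    rows.foreach { row =>
      if (row.isNullAt(ordinal)) writer.writeNull()
      else writer.write(row, ordinal)
    }
    val (fieldNodes, bufs) = writer.finish()  // exactly once, after all writes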
@@ -142,7 +155,7 @@ private[sql] abstract class PrimitiveColumnWriter(protected val allocator: BaseAllocator)
   override def finish(): (Seq[ArrowFieldNode], Seq[ArrowBuf]) = {
     valueMutator.setValueCount(count)
     val fieldNode = new ArrowFieldNode(count, nullCount)
-    val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true)  // TODO: check the flag
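+    // clear = true asks the vector to drop its own references, so the
+    // returned ArrowBufs should be the only live references (an assumption
+    // about this Arrow version's ValueVector.getBuffers semantics); ownership
+    // moves to the caller, which is why the TODO could be dropped.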
+    val valueBuffers: Seq[ArrowBuf] = valueVector.getBuffers(true)
     (List(fieldNode), valueBuffers)
   }
 }
@@ -239,6 +252,44 @@ private[sql] class UTF8StringColumnWriter(allocator: BaseAllocator)
   }
 }
 
+private[sql] class BinaryColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableVarBinaryVector
+    = new NullableVarBinaryVector("BinaryValue", allocator)
+  override protected val valueMutator: NullableVarBinaryVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
+    val bytes = row.getBinary(ordinal)
+    valueMutator.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[sql] class DateColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableDateVector
+    = new NullableDateVector("DateValue", allocator)
+  override protected val valueMutator: NullableDateVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
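+    // Spark stores DateType as days since the Unix epoch; the Arrow date
+    // vector holds milliseconds, so scale by 24 * 3600 * 1000 = 86,400,000
+    // ms/day (e.g. day 16801 = 2016-01-01 becomes 1451606400000L).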
+    valueMutator.setSafe(count, row.getInt(ordinal).toLong * 24 * 3600 * 1000)
+  }
+}
+
+private[sql] class TimeStampColumnWriter(allocator: BaseAllocator)
+  extends PrimitiveColumnWriter(allocator) {
+  override protected val valueVector: NullableTimeStampVector
+    = new NullableTimeStampVector("TimeStampValue", allocator)
+  override protected val valueMutator: NullableTimeStampVector#Mutator = valueVector.getMutator
+
+  override protected def setNull(): Unit = valueMutator.setNull(count)
+
+  override protected def setValue(row: InternalRow, ordinal: Int): Unit = {
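+    // Spark TimestampType carries microseconds since the epoch, while the
+    // vector was declared with TimeUnit.MILLISECOND, so integer division by
+    // 1000 converts the unit (truncating any sub-millisecond component).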
+    valueMutator.setSafe(count, row.getLong(ordinal) / 1000)
+  }
+}
+
 private[sql] object ColumnWriter {
   def apply(allocator: BaseAllocator, dataType: DataType): ColumnWriter = {
     dataType match {
@@ -250,7 +301,10 @@ private[sql] object ColumnWriter {
       case DoubleType => new DoubleColumnWriter(allocator)
       case ByteType => new ByteColumnWriter(allocator)
       case StringType => new UTF8StringColumnWriter(allocator)
-      case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dataType}")
+      case BinaryType => new BinaryColumnWriter(allocator)
+      case DateType => new DateColumnWriter(allocator)
+      case TimestampType => new TimeStampColumnWriter(allocator)
+      case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType")
     }
   }
 }
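With the three new branches, the factory covers every type the ArrowType mapping above accepts; anything else still fails fast. For example (hypothetical usage):

    ColumnWriter(allocator, DateType)              // DateColumnWriter
    ColumnWriter(allocator, CalendarIntervalType)  // throws UnsupportedOperationException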