1+ /*
2+ * Licensed to the Apache Software Foundation (ASF) under one or more
3+ * contributor license agreements. See the NOTICE file distributed with
4+ * this work for additional information regarding copyright ownership.
5+ * The ASF licenses this file to You under the Apache License, Version 2.0
6+ * (the "License"); you may not use this file except in compliance with
7+ * the License. You may obtain a copy of the License at
8+ *
9+ * http://www.apache.org/licenses/LICENSE-2.0
10+ *
11+ * Unless required by applicable law or agreed to in writing, software
12+ * distributed under the License is distributed on an "AS IS" BASIS,
13+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+ * See the License for the specific language governing permissions and
15+ * limitations under the License.
16+ */
17+
18+ package org .apache .spark .sql .catalyst .expressions
19+
20+ import org .apache .spark .sql .types ._
21+ import org .apache .spark .unsafe .PlatformDependent
22+ import org .apache .spark .unsafe .array .ByteArrayMethods
23+
24+ /** Write a column into an UnsafeRow */
25+ private abstract class UnsafeColumnWriter [T ] {
26+ /**
27+ * Write a value into an UnsafeRow.
28+ *
29+ * @param value the value to write
30+ * @param columnNumber what column to write it to
31+ * @param row a pointer to the unsafe row
32+ * @param baseObject
33+ * @param baseOffset
34+ * @param appendCursor the offset from the start of the unsafe row to the end of the row;
35+ * used for calculating where variable-length data should be written
36+ * @return the number of variable-length bytes written
37+ */
38+ def write (
39+ value : T ,
40+ columnNumber : Int ,
41+ row : UnsafeRow ,
42+ baseObject : Object ,
43+ baseOffset : Long ,
44+ appendCursor : Int ): Int
45+
46+ /**
47+ * Return the number of bytes that are needed to write this variable-length value.
48+ */
49+ def getSize (value : T ): Int
50+ }
51+
52+ private object UnsafeColumnWriter {
53+ def forType (dataType : DataType ): UnsafeColumnWriter [_] = {
54+ dataType match {
55+ case IntegerType => IntUnsafeColumnWriter
56+ case LongType => LongUnsafeColumnWriter
57+ case StringType => StringUnsafeColumnWriter
58+ case _ => throw new UnsupportedOperationException ()
59+ }
60+ }
61+ }
62+
63+ private class StringUnsafeColumnWriter private () extends UnsafeColumnWriter [UTF8String ] {
64+ def getSize (value : UTF8String ): Int = {
65+ // round to nearest word
66+ val numBytes = value.getBytes.length
67+ 8 + ByteArrayMethods .roundNumberOfBytesToNearestWord(numBytes)
68+ }
69+
70+ override def write (
71+ value : UTF8String ,
72+ columnNumber : Int ,
73+ row : UnsafeRow ,
74+ baseObject : Object ,
75+ baseOffset : Long ,
76+ appendCursor : Int ): Int = {
77+ val numBytes = value.getBytes.length
78+ PlatformDependent .UNSAFE .putLong(baseObject, baseOffset + appendCursor, numBytes)
79+ PlatformDependent .copyMemory(
80+ value.getBytes,
81+ PlatformDependent .BYTE_ARRAY_OFFSET ,
82+ baseObject,
83+ baseOffset + appendCursor + 8 ,
84+ numBytes
85+ )
86+ row.setLong(columnNumber, appendCursor)
87+ 8 + ((numBytes / 8 ) + (if (numBytes % 8 == 0 ) 0 else 1 )) * 8
88+ }
89+ }
90+ private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
91+
92+ private abstract class PrimitiveUnsafeColumnWriter [T ] extends UnsafeColumnWriter [T ] {
93+ def getSize (value : T ): Int = 0
94+ }
95+
96+ private class IntUnsafeColumnWriter private () extends PrimitiveUnsafeColumnWriter [Int ] {
97+ override def write (
98+ value : Int ,
99+ columnNumber : Int ,
100+ row : UnsafeRow ,
101+ baseObject : Object ,
102+ baseOffset : Long ,
103+ appendCursor : Int ): Int = {
104+ row.setInt(columnNumber, value)
105+ 0
106+ }
107+ }
108+ private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
109+
110+ private class LongUnsafeColumnWriter private () extends PrimitiveUnsafeColumnWriter [Long ] {
111+ override def write (
112+ value : Long ,
113+ columnNumber : Int ,
114+ row : UnsafeRow ,
115+ baseObject : Object ,
116+ baseOffset : Long ,
117+ appendCursor : Int ): Int = {
118+ row.setLong(columnNumber, value)
119+ 0
120+ }
121+ }
122+ private case object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
123+
124+
125+ class UnsafeRowConverter (fieldTypes : Array [DataType ]) {
126+
127+ private [this ] val writers : Array [UnsafeColumnWriter [Any ]] = {
128+ fieldTypes.map(t => UnsafeColumnWriter .forType(t).asInstanceOf [UnsafeColumnWriter [Any ]])
129+ }
130+
131+ def getSizeRequirement (row : Row ): Int = {
132+ var fieldNumber = 0
133+ var variableLengthFieldSize : Int = 0
134+ while (fieldNumber < writers.length) {
135+ if (! row.isNullAt(fieldNumber)) {
136+ variableLengthFieldSize += writers(fieldNumber).getSize(row.get(fieldNumber))
137+
138+ }
139+ fieldNumber += 1
140+ }
141+ (8 * fieldTypes.length) + UnsafeRow .calculateBitSetWidthInBytes(fieldTypes.length) + variableLengthFieldSize
142+ }
143+
144+ def writeRow (row : Row , baseObject : Object , baseOffset : Long ): Long = {
145+ val unsafeRow = new UnsafeRow ()
146+ unsafeRow.set(baseObject, baseOffset, writers.length, null ) // TODO: schema?
147+ var fieldNumber = 0
148+ var appendCursor : Int =
149+ (8 * fieldTypes.length) + UnsafeRow .calculateBitSetWidthInBytes(fieldTypes.length)
150+ while (fieldNumber < writers.length) {
151+ if (row.isNullAt(fieldNumber)) {
152+ unsafeRow.setNullAt(fieldNumber)
153+ // TODO: type-specific null value writing?
154+ } else {
155+ appendCursor += writers(fieldNumber).write(
156+ row.get(fieldNumber),
157+ fieldNumber,
158+ unsafeRow,
159+ baseObject,
160+ baseOffset,
161+ appendCursor)
162+ }
163+ fieldNumber += 1
164+ }
165+ appendCursor
166+ }
167+
168+ }
0 commit comments