@@ -32,6 +32,8 @@ import org.apache.spark.sql.types._
3232 */
3333object GenerateUnsafeProjection extends CodeGenerator [Seq [Expression ], UnsafeProjection ] {
3434
/**
 * Pairs a value's data type with its nullability flag, so that the generated writer code
 * can leave out the null check for values that are declared non-nullable.
 */
35+ case class Schema (dataType : DataType , nullable : Boolean )
36+
3537 /** Returns true iff we support this data type. */
3638 def canSupport (dataType : DataType ): Boolean = UserDefinedType .sqlType(dataType) match {
3739 case NullType => true
@@ -43,19 +45,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4345 case _ => false
4446 }
4547
46- // TODO: if the nullability of field is correct, we can use it to save null check.
4748 private def writeStructToBuffer (
4849 ctx : CodegenContext ,
4950 input : String ,
5051 index : String ,
51- fieldTypes : Seq [DataType ],
52+ schemas : Seq [Schema ],
5253 rowWriter : String ): String = {
5354 // Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
5455 val tmpInput = ctx.freshName(" tmpInput" )
55- val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
56- ExprCode (
57- JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " ),
58- JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
56+ val fieldEvals = schemas.zipWithIndex.map { case (Schema (dt, nullable), i) =>
57+ val isNull = if (nullable) {
58+ JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " )
59+ } else {
60+ FalseLiteral
61+ }
62+ ExprCode (isNull, JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
5963 }
6064
6165 val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -70,7 +74,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7074 | // Remember the current cursor so that we can calculate how many bytes are
7175 | // written later.
7276 | final int $previousCursor = $rowWriter.cursor();
73- | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes , structRowWriter)}
77+ | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas , structRowWriter)}
7478 | $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
7579 |}
7680 """ .stripMargin
@@ -80,7 +84,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
8084 ctx : CodegenContext ,
8185 row : String ,
8286 inputs : Seq [ExprCode ],
83- inputTypes : Seq [DataType ],
87+ schemas : Seq [Schema ],
8488 rowWriter : String ,
8589 isTopLevel : Boolean = false ): String = {
8690 val resetWriter = if (isTopLevel) {
@@ -98,8 +102,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
98102 s " $rowWriter.resetRowWriter(); "
99103 }
100104
101- val writeFields = inputs.zip(inputTypes ).zipWithIndex.map {
102- case ((input, dataType), index) =>
105+ val writeFields = inputs.zip(schemas ).zipWithIndex.map {
106+ case ((input, Schema ( dataType, nullable) ), index) =>
103107 val dt = UserDefinedType .sqlType(dataType)
104108
105109 val setNull = dt match {
@@ -110,7 +114,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
110114 }
111115
112116 val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
113- if (input.isNull == FalseLiteral ) {
117+ if (! nullable ) {
114118 s """
115119 | ${input.code}
116120 | ${writeField.trim}
@@ -143,11 +147,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
143147 """ .stripMargin
144148 }
145149
146- // TODO: if the nullability of array element is correct, we can use it to save null check.
147150 private def writeArrayToBuffer (
148151 ctx : CodegenContext ,
149152 input : String ,
150153 elementType : DataType ,
154+ containsNull : Boolean ,
151155 rowWriter : String ): String = {
152156 // Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
153157 val tmpInput = ctx.freshName(" tmpInput" )
@@ -170,6 +174,18 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
170174
171175 val element = CodeGenerator .getValue(tmpInput, et, index)
172176
177+ val elementAssignment = if (containsNull) {
178+ s """
179+ |if ( $tmpInput.isNullAt( $index)) {
180+ | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
181+ |} else {
182+ | ${writeElement(ctx, element, index, et, arrayWriter)}
183+ |}
184+ """ .stripMargin
185+ } else {
186+ writeElement(ctx, element, index, et, arrayWriter)
187+ }
188+
173189 s """
174190 |final ArrayData $tmpInput = $input;
175191 |if ( $tmpInput instanceof UnsafeArrayData) {
@@ -179,30 +195,31 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
179195 | $arrayWriter.initialize( $numElements);
180196 |
181197 | for (int $index = 0; $index < $numElements; $index++) {
182- | if ( $tmpInput.isNullAt( $index)) {
183- | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
184- | } else {
185- | ${writeElement(ctx, element, index, et, arrayWriter)}
186- | }
198+ | $elementAssignment
187199 | }
188200 |}
189201 """ .stripMargin
190202 }
191203
192- // TODO: if the nullability of value element is correct, we can use it to save null check.
193204 private def writeMapToBuffer (
194205 ctx : CodegenContext ,
195206 input : String ,
196207 index : String ,
197208 keyType : DataType ,
198209 valueType : DataType ,
210+ valueContainsNull : Boolean ,
199211 rowWriter : String ): String = {
200212 // Puts `input` in a local variable to avoid re-evaluating it if it's a statement.
201213 val tmpInput = ctx.freshName(" tmpInput" )
202214 val tmpCursor = ctx.freshName(" tmpCursor" )
203215 val previousCursor = ctx.freshName(" previousCursor" )
204216
205217 // Writes out unsafe map according to the format described in `UnsafeMapData`.
218+ val keyArray = writeArrayToBuffer(
219+ ctx, s " $tmpInput.keyArray() " , keyType, false , rowWriter)
220+ val valueArray = writeArrayToBuffer(
221+ ctx, s " $tmpInput.valueArray() " , valueType, valueContainsNull, rowWriter)
222+
206223 s """
207224 |final MapData $tmpInput = $input;
208225 |if ( $tmpInput instanceof UnsafeMapData) {
@@ -219,15 +236,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
219236 | // Remember the current cursor so that we can write numBytes of key array later.
220237 | final int $tmpCursor = $rowWriter.cursor();
221238 |
222- | ${writeArrayToBuffer(ctx, s " $tmpInput . keyArray() " , keyType, rowWriter)}
239+ | $keyArray
223240 |
224241 | // Write the numBytes of key array into the first 8 bytes.
225242 | Platform.putLong(
226243 | $rowWriter.getBuffer(),
227244 | $tmpCursor - 8,
228245 | $rowWriter.cursor() - $tmpCursor);
229246 |
230- | ${writeArrayToBuffer(ctx, s " $tmpInput . valueArray() " , valueType, rowWriter)}
247+ | $valueArray
231248 | $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
232249 |}
233250 """ .stripMargin
@@ -240,20 +257,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
240257 dt : DataType ,
241258 writer : String ): String = dt match {
242259 case t : StructType =>
243- writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
260+ writeStructToBuffer(
261+ ctx, input, index, t.map(e => Schema (e.dataType, e.nullable)), writer)
244262
245- case ArrayType (et, _ ) =>
263+ case ArrayType (et, en ) =>
246264 val previousCursor = ctx.freshName(" previousCursor" )
247265 s """
248266 |// Remember the current cursor so that we can calculate how many bytes are
249267 |// written later.
250268 |final int $previousCursor = $writer.cursor();
251- | ${writeArrayToBuffer(ctx, input, et, writer)}
269+ | ${writeArrayToBuffer(ctx, input, et, en, writer)}
252270 | $writer.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
253271 """ .stripMargin
254272
255- case MapType (kt, vt, _ ) =>
256- writeMapToBuffer(ctx, input, index, kt, vt, writer)
273+ case MapType (kt, vt, vn ) =>
274+ writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
257275
258276 case DecimalType .Fixed (precision, scale) =>
259277 s " $writer.write( $index, $input, $precision, $scale); "
@@ -268,12 +286,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
268286 expressions : Seq [Expression ],
269287 useSubexprElimination : Boolean = false ): ExprCode = {
270288 val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
271- val exprTypes = expressions.map(_ .dataType)
289+ val exprSchemas = expressions.map(e => Schema (e .dataType, e.nullable) )
272290
273- val numVarLenFields = exprTypes .count {
274- case dt if UnsafeRow .isFixedLength(dt) => false
291+ val numVarLenFields = exprSchemas .count {
292+ case Schema (dt, _) => ! UnsafeRow .isFixedLength(dt)
275293 // TODO: consider large decimal and interval type
276- case _ => true
277294 }
278295
279296 val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -284,7 +301,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
284301 val evalSubexpr = ctx.subexprFunctions.mkString(" \n " )
285302
286303 val writeExpressions = writeExpressionsToBuffer(
287- ctx, ctx.INPUT_ROW , exprEvals, exprTypes , rowWriter, isTopLevel = true )
304+ ctx, ctx.INPUT_ROW , exprEvals, exprSchemas , rowWriter, isTopLevel = true )
288305
289306 val code =
290307 code """
0 commit comments