@@ -32,8 +32,6 @@ import org.apache.spark.sql.types._
3232 */
3333object GenerateUnsafeProjection extends CodeGenerator [Seq [Expression ], UnsafeProjection ] {
3434
35- case class Schema (dataType : DataType , nullable : Boolean )
36-
3735 /** Returns true iff we support this data type. */
3836 def canSupport (dataType : DataType ): Boolean = UserDefinedType .sqlType(dataType) match {
3937 case NullType => true
@@ -45,21 +43,19 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
4543 case _ => false
4644 }
4745
46+ // TODO: if the nullability of field is correct, we can use it to save null check.
4847 private def writeStructToBuffer (
4948 ctx : CodegenContext ,
5049 input : String ,
5150 index : String ,
52- schemas : Seq [Schema ],
51+ fieldTypes : Seq [DataType ],
5352 rowWriter : String ): String = {
5453 // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
5554 val tmpInput = ctx.freshName(" tmpInput" )
56- val fieldEvals = schemas.zipWithIndex.map { case (Schema (dt, nullable), i) =>
57- val isNull = if (nullable) {
58- JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " )
59- } else {
60- FalseLiteral
61- }
62- ExprCode (isNull, JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
55+ val fieldEvals = fieldTypes.zipWithIndex.map { case (dt, i) =>
56+ ExprCode (
57+ JavaCode .isNullExpression(s " $tmpInput.isNullAt( $i) " ),
58+ JavaCode .expression(CodeGenerator .getValue(tmpInput, dt, i.toString), dt))
6359 }
6460
6561 val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -74,7 +70,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
7470 | // Remember the current cursor so that we can calculate how many bytes are
7571 | // written later.
7672 | final int $previousCursor = $rowWriter.cursor();
77- | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, schemas , structRowWriter)}
73+ | ${writeExpressionsToBuffer(ctx, tmpInput, fieldEvals, fieldTypes , structRowWriter)}
7874 | $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
7975 |}
8076 """ .stripMargin
@@ -84,7 +80,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
8480 ctx : CodegenContext ,
8581 row : String ,
8682 inputs : Seq [ExprCode ],
87- schemas : Seq [Schema ],
83+ inputTypes : Seq [DataType ],
8884 rowWriter : String ,
8985 isTopLevel : Boolean = false ): String = {
9086 val resetWriter = if (isTopLevel) {
@@ -102,8 +98,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
10298 s " $rowWriter.resetRowWriter(); "
10399 }
104100
105- val writeFields = inputs.zip(schemas ).zipWithIndex.map {
106- case ((input, Schema ( dataType, nullable) ), index) =>
101+ val writeFields = inputs.zip(inputTypes ).zipWithIndex.map {
102+ case ((input, dataType), index) =>
107103 val dt = UserDefinedType .sqlType(dataType)
108104
109105 val setNull = dt match {
@@ -114,7 +110,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
114110 }
115111
116112 val writeField = writeElement(ctx, input.value, index.toString, dt, rowWriter)
117- if (! nullable ) {
113+ if (input.isNull == FalseLiteral ) {
118114 s """
119115 | ${input.code}
120116 | ${writeField.trim}
@@ -147,11 +143,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
147143 """ .stripMargin
148144 }
149145
146+ // TODO: if the nullability of array element is correct, we can use it to save null check.
150147 private def writeArrayToBuffer (
151148 ctx : CodegenContext ,
152149 input : String ,
153150 elementType : DataType ,
154- containsNull : Boolean ,
155151 rowWriter : String ): String = {
156152 // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
157153 val tmpInput = ctx.freshName(" tmpInput" )
@@ -174,18 +170,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
174170
175171 val element = CodeGenerator .getValue(tmpInput, et, index)
176172
177- val elementAssignment = if (containsNull) {
178- s """
179- |if ( $tmpInput.isNullAt( $index)) {
180- | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
181- |} else {
182- | ${writeElement(ctx, element, index, et, arrayWriter)}
183- |}
184- """ .stripMargin
185- } else {
186- writeElement(ctx, element, index, et, arrayWriter)
187- }
188-
189173 s """
190174 |final ArrayData $tmpInput = $input;
191175 |if ( $tmpInput instanceof UnsafeArrayData) {
@@ -195,31 +179,30 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
195179 | $arrayWriter.initialize( $numElements);
196180 |
197181 | for (int $index = 0; $index < $numElements; $index++) {
198- | $elementAssignment
182+ | if ( $tmpInput.isNullAt( $index)) {
183+ | $arrayWriter.setNull ${elementOrOffsetSize}Bytes( $index);
184+ | } else {
185+ | ${writeElement(ctx, element, index, et, arrayWriter)}
186+ | }
199187 | }
200188 |}
201189 """ .stripMargin
202190 }
203191
192+ // TODO: if the nullability of value element is correct, we can use it to save null check.
204193 private def writeMapToBuffer (
205194 ctx : CodegenContext ,
206195 input : String ,
207196 index : String ,
208197 keyType : DataType ,
209198 valueType : DataType ,
210- valueContainsNull : Boolean ,
211199 rowWriter : String ): String = {
212200 // Puts `input` in a local variable to avoid to re-evaluate it if it's a statement.
213201 val tmpInput = ctx.freshName(" tmpInput" )
214202 val tmpCursor = ctx.freshName(" tmpCursor" )
215203 val previousCursor = ctx.freshName(" previousCursor" )
216204
217205 // Writes out unsafe map according to the format described in `UnsafeMapData`.
218- val keyArray = writeArrayToBuffer(
219- ctx, s " $tmpInput.keyArray() " , keyType, false , rowWriter)
220- val valueArray = writeArrayToBuffer(
221- ctx, s " $tmpInput.valueArray() " , valueType, valueContainsNull, rowWriter)
222-
223206 s """
224207 |final MapData $tmpInput = $input;
225208 |if ( $tmpInput instanceof UnsafeMapData) {
@@ -236,15 +219,15 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
236219 | // Remember the current cursor so that we can write numBytes of key array later.
237220 | final int $tmpCursor = $rowWriter.cursor();
238221 |
239- | $keyArray
222+ | ${writeArrayToBuffer(ctx, s " $tmpInput . keyArray() " , keyType, rowWriter)}
240223 |
241224 | // Write the numBytes of key array into the first 8 bytes.
242225 | Platform.putLong(
243226 | $rowWriter.getBuffer(),
244227 | $tmpCursor - 8,
245228 | $rowWriter.cursor() - $tmpCursor);
246229 |
247- | $valueArray
230+ | ${writeArrayToBuffer(ctx, s " $tmpInput . valueArray() " , valueType, rowWriter)}
248231 | $rowWriter.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
249232 |}
250233 """ .stripMargin
@@ -257,21 +240,20 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
257240 dt : DataType ,
258241 writer : String ): String = dt match {
259242 case t : StructType =>
260- writeStructToBuffer(
261- ctx, input, index, t.map(e => Schema (e.dataType, e.nullable)), writer)
243+ writeStructToBuffer(ctx, input, index, t.map(_.dataType), writer)
262244
263- case ArrayType (et, en ) =>
245+ case ArrayType (et, _ ) =>
264246 val previousCursor = ctx.freshName(" previousCursor" )
265247 s """
266248 |// Remember the current cursor so that we can calculate how many bytes are
267249 |// written later.
268250 |final int $previousCursor = $writer.cursor();
269- | ${writeArrayToBuffer(ctx, input, et, en, writer)}
251+ | ${writeArrayToBuffer(ctx, input, et, writer)}
270252 | $writer.setOffsetAndSizeFromPreviousCursor( $index, $previousCursor);
271253 """ .stripMargin
272254
273- case MapType (kt, vt, vn ) =>
274- writeMapToBuffer(ctx, input, index, kt, vt, vn, writer)
255+ case MapType (kt, vt, _ ) =>
256+ writeMapToBuffer(ctx, input, index, kt, vt, writer)
275257
276258 case DecimalType .Fixed (precision, scale) =>
277259 s " $writer.write( $index, $input, $precision, $scale); "
@@ -286,11 +268,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
286268 expressions : Seq [Expression ],
287269 useSubexprElimination : Boolean = false ): ExprCode = {
288270 val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination)
289- val exprSchemas = expressions.map(e => Schema (e .dataType, e.nullable) )
271+ val exprTypes = expressions.map(_ .dataType)
290272
291- val numVarLenFields = exprSchemas .count {
292- case Schema (dt, _) => ! UnsafeRow .isFixedLength(dt)
273+ val numVarLenFields = exprTypes .count {
274+ case dt if UnsafeRow .isFixedLength(dt) => false
293275 // TODO: consider large decimal and interval type
276+ case _ => true
294277 }
295278
296279 val rowWriterClass = classOf [UnsafeRowWriter ].getName
@@ -301,7 +284,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
301284 val evalSubexpr = ctx.subexprFunctions.mkString(" \n " )
302285
303286 val writeExpressions = writeExpressionsToBuffer(
304- ctx, ctx.INPUT_ROW , exprEvals, exprSchemas , rowWriter, isTopLevel = true )
287+ ctx, ctx.INPUT_ROW , exprEvals, exprTypes , rowWriter, isTopLevel = true )
305288
306289 val code =
307290 code """
0 commit comments