Skip to content

Conversation

@michalsenkyr
Copy link
Contributor

@michalsenkyr michalsenkyr commented Feb 18, 2017

What changes were proposed in this pull request?

Add support for arbitrary Scala Map types in deserialization as well as a generic implicit encoder.

Used the builder approach as in #16541 to construct any provided Map type upon deserialization.

Please note that this PR also adds (ignored) tests for issue SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0 but doesn't solve it.

Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes:

  • java.util.Map, java.util.AbstractMap => java.util.HashMap
  • java.util.SortedMap, java.util.NavigableMap => java.util.TreeMap
  • java.util.concurrent.ConcurrentMap => java.util.concurrent.ConcurrentHashMap
  • java.util.concurrent.ConcurrentNavigableMap => java.util.concurrent.ConcurrentSkipListMap

Resulting codegen for Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen:

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private scala.collection.immutable.Map mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */
/* 112 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 113 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 114 */
/* 115 */         int deserializetoobject_loopIndex = 0;
/* 116 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 117 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 118 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 119 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 120 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 121 */
/* 122 */           if (CollectObjectsToMap_loopIsNull1) {
/* 123 */             throw new RuntimeException("Found null in map key!");
/* 124 */           }
/* 125 */
/* 126 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 127 */
/* 128 */           if (CollectObjectsToMap_loopIsNull3) {
/* 129 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 130 */           } else {
/* 131 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 132 */           }
/* 133 */
/* 134 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 135 */
/* 136 */           deserializetoobject_loopIndex += 1;
/* 137 */         }
/* 138 */
/* 139 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 140 */       }
/* 141 */
/* 142 */       boolean mapelements_isNull = true;
/* 143 */       scala.collection.immutable.Map mapelements_value = null;
/* 144 */       if (!false) {
/* 145 */         mapelements_argValue = deserializetoobject_value;
/* 146 */
/* 147 */         mapelements_isNull = false;
/* 148 */         if (!mapelements_isNull) {
/* 149 */           Object mapelements_funcResult = null;
/* 150 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 151 */           if (mapelements_funcResult == null) {
/* 152 */             mapelements_isNull = true;
/* 153 */           } else {
/* 154 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 155 */           }
/* 156 */
/* 157 */         }
/* 158 */         mapelements_isNull = mapelements_value == null;
/* 159 */       }
/* 160 */
/* 161 */       MapData serializefromobject_value = null;
/* 162 */       if (!mapelements_isNull) {
/* 163 */         final int serializefromobject_length = mapelements_value.size();
/* 164 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 165 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 166 */         int serializefromobject_index = 0;
/* 167 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 168 */         while(serializefromobject_entries.hasNext()) {
/* 169 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 170 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 171 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 172 */
/* 173 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 174 */
/* 175 */           if (false) {
/* 176 */             throw new RuntimeException("Cannot use null as map key!");
/* 177 */           } else {
/* 178 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 179 */           }
/* 180 */
/* 181 */           if (false) {
/* 182 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 183 */           } else {
/* 184 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 185 */           }
/* 186 */
/* 187 */           serializefromobject_index++;
/* 188 */         }
/* 189 */
/* 190 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 191 */       }
/* 192 */       serializefromobject_holder.reset();
/* 193 */
/* 194 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 195 */
/* 196 */       if (mapelements_isNull) {
/* 197 */         serializefromobject_rowWriter.setNullAt(0);
/* 198 */       } else {
/* 199 */         // Remember the current cursor so that we can calculate how many bytes are
/* 200 */         // written later.
/* 201 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 202 */
/* 203 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 204 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 205 */           // grow the global buffer before writing data.
/* 206 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 207 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 208 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 209 */
/* 210 */         } else {
/* 211 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 212 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 213 */
/* 214 */           // preserve 8 bytes to write the key array numBytes later.
/* 215 */           serializefromobject_holder.grow(8);
/* 216 */           serializefromobject_holder.cursor += 8;
/* 217 */
/* 218 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 219 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 220 */
/* 221 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 222 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 223 */             // grow the global buffer before writing data.
/* 224 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 225 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 226 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 227 */
/* 228 */           } else {
/* 229 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 230 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 231 */
/* 232 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 233 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 234 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 235 */               } else {
/* 236 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 237 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 238 */               }
/* 239 */             }
/* 240 */           }
/* 241 */
/* 242 */           // Write the numBytes of key array into the first 8 bytes.
/* 243 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 244 */
/* 245 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 246 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 247 */             // grow the global buffer before writing data.
/* 248 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 249 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 250 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 251 */
/* 252 */           } else {
/* 253 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 254 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 255 */
/* 256 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 257 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 258 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 259 */               } else {
/* 260 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 261 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 262 */               }
/* 263 */             }
/* 264 */           }
/* 265 */
/* 266 */         }
/* 267 */
/* 268 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 269 */       }
/* 270 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 271 */       append(serializefromobject_result);
/* 272 */       if (shouldStop()) return;
/* 273 */     }
/* 274 */   }
/* 275 */ }

Codegen for java.util.Map:

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private java.util.HashMap mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       java.util.HashMap deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */         java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength);
/* 112 */
/* 113 */         int deserializetoobject_loopIndex = 0;
/* 114 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 115 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 116 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 117 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 118 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 119 */
/* 120 */           if (CollectObjectsToMap_loopIsNull1) {
/* 121 */             throw new RuntimeException("Found null in map key!");
/* 122 */           }
/* 123 */
/* 124 */           CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 125 */
/* 126 */           deserializetoobject_loopIndex += 1;
/* 127 */         }
/* 128 */
/* 129 */         deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5;
/* 130 */       }
/* 131 */
/* 132 */       boolean mapelements_isNull = true;
/* 133 */       java.util.HashMap mapelements_value = null;
/* 134 */       if (!false) {
/* 135 */         mapelements_argValue = deserializetoobject_value;
/* 136 */
/* 137 */         mapelements_isNull = false;
/* 138 */         if (!mapelements_isNull) {
/* 139 */           Object mapelements_funcResult = null;
/* 140 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 141 */           if (mapelements_funcResult == null) {
/* 142 */             mapelements_isNull = true;
/* 143 */           } else {
/* 144 */             mapelements_value = (java.util.HashMap) mapelements_funcResult;
/* 145 */           }
/* 146 */
/* 147 */         }
/* 148 */         mapelements_isNull = mapelements_value == null;
/* 149 */       }
/* 150 */
/* 151 */       MapData serializefromobject_value = null;
/* 152 */       if (!mapelements_isNull) {
/* 153 */         final int serializefromobject_length = mapelements_value.size();
/* 154 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 155 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 156 */         int serializefromobject_index = 0;
/* 157 */         final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator();
/* 158 */         while(serializefromobject_entries.hasNext()) {
/* 159 */           final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next();
/* 160 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey();
/* 161 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue();
/* 162 */
/* 163 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 164 */
/* 165 */           if (false) {
/* 166 */             throw new RuntimeException("Cannot use null as map key!");
/* 167 */           } else {
/* 168 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 169 */           }
/* 170 */
/* 171 */           if (false) {
/* 172 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 173 */           } else {
/* 174 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 175 */           }
/* 176 */
/* 177 */           serializefromobject_index++;
/* 178 */         }
/* 179 */
/* 180 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 181 */       }
/* 182 */       serializefromobject_holder.reset();
/* 183 */
/* 184 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 185 */
/* 186 */       if (mapelements_isNull) {
/* 187 */         serializefromobject_rowWriter.setNullAt(0);
/* 188 */       } else {
/* 189 */         // Remember the current cursor so that we can calculate how many bytes are
/* 190 */         // written later.
/* 191 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 192 */
/* 193 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 194 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 195 */           // grow the global buffer before writing data.
/* 196 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 197 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 198 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 199 */
/* 200 */         } else {
/* 201 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 202 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 203 */
/* 204 */           // preserve 8 bytes to write the key array numBytes later.
/* 205 */           serializefromobject_holder.grow(8);
/* 206 */           serializefromobject_holder.cursor += 8;
/* 207 */
/* 208 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 209 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 210 */
/* 211 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 212 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 213 */             // grow the global buffer before writing data.
/* 214 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 215 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 216 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 217 */
/* 218 */           } else {
/* 219 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 220 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 221 */
/* 222 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 223 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 224 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 225 */               } else {
/* 226 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 227 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 228 */               }
/* 229 */             }
/* 230 */           }
/* 231 */
/* 232 */           // Write the numBytes of key array into the first 8 bytes.
/* 233 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 234 */
/* 235 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 236 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 237 */             // grow the global buffer before writing data.
/* 238 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 239 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 240 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 241 */
/* 242 */           } else {
/* 243 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 244 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 245 */
/* 246 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 247 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 248 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 249 */               } else {
/* 250 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 251 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 252 */               }
/* 253 */             }
/* 254 */           }
/* 255 */
/* 256 */         }
/* 257 */
/* 258 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 259 */       }
/* 260 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 261 */       append(serializefromobject_result);
/* 262 */       if (shouldStop()) return;
/* 263 */     }
/* 264 */   }
/* 265 */ }

How was this patch tested?

build/mvn -DskipTests clean package && dev/run-tests

Additionally in Spark shell:

scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))

@michalsenkyr
Copy link
Contributor Author

michalsenkyr commented Feb 26, 2017

Added support for Java Maps with support for pre-allocation (capacity argument on constructor) and sensible defaults for interfaces/abstract classes. Also includes implicit encoders.
Updated codegen in description (only a cosmetic change), added codegen for Java Map and listed all defined defaults.

@michalsenkyr
Copy link
Contributor Author

Rebased onto the current master and integrated a few minor changes from the code review of #16541 in case anyone is still interested in this feature

UnresolvedMapObjects(mapFunction, getPath, Some(cls))

case t if t <:< localTypeOf[Map[_, _]] =>
case t if t <:< localTypeOf[Map[_, _]] || t <:< localTypeOf[java.util.Map[_, _]] =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should handle java map in JavaTypeInference, but I think it's better to do it in another PR and focus on scala map in this PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright. Should I remove the case condition modifications in this PR or leave them as they are and remove them in the next one?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's remove them and related java map tests in this PR and add them in next PR

@michalsenkyr michalsenkyr changed the title [SPARK-18891][SQL] Support for Map collection types [SPARK-18891][SQL] Support for Scala Map collection types May 19, 2017
@cloud-fan
Copy link
Contributor

sorry it conflicts...

@michalsenkyr
Copy link
Contributor Author

That was because of my other PR that just got accepted. Just a matter of appending unit tests. I resolved the conflict from browser for now. Can rebase later if merge commits are not encouraged in PRs.

collClass: Class[_]): CollectObjectsToMap = {
val id = curId.getAndIncrement()
val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
val keyLoopIsNull = s"CollectObjectsToMap_keyLoopIsNull$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this? the map key can not be null by definition

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. A key in MapData cannot be null. However, since the function takes two ArrayDatas as input, I figured that we shouldn't count on this requirement being necessarily fulfilled. As CollectObjectsToMap is a class separate from its usage in ScalaReflection, I tried to make it as generic and as similar to MapObjects as possible, so it can be used elsewhere without having to make sure additional preconditions are met.
It also produces a generic Map which has implementations that can support null keys. Right now, the only check that prevents this is here. If there is ever a need to support these kinds of Maps in the future, this should make the job easier.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

tbh I think this new expression should only be used to turn a catalyst map to external map, and we don't need to generalize it. We can even let it only accept a map type input, instead of 2 array inputs.

keyData :: valueData :: Nil)
CollectObjectsToMap(
p => deserializerFor(keyType, Some(p), walkedTypePath),
Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should set returnNullable to false here

case class CollectObjectsToMap private(
keyLoopValue: String,
keyLoopIsNull: String,
keyLoopVarDataType: DataType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need key/value data types as parameters? We can easily get them from key/value input data expression

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I modelled this class after the MapObjects class so that they could be used similarly. I noticed that since then a new UnresolvedMapObjects class was introduced which also doesn't require the element data type. Would this be something similar? And if so, shouldn't I rather introduce a new UnresolvedCollectObjectsToMap class instead?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UnresolvedMapObject is used for dynamic type mapping of array element, but we don't need this for map element.

val loopIndex = ctx.freshName("loopIndex")

// In RowEncoder, we use `Object` to represent Array or Seq, so we need to determine the type
// of input collection at runtime for this case.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this here. The key/value arrays come from MapData.getKeyArray, so there is no need to determine the type at runtime because it's always ArrayData

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As mentioned earlier, I tried to make this class as generic and similar to MapObjects as possible so it can be used elsewhere without certain preconditions being met. Granted that getting sequences here in the future is unlikely. Should I remove it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea let remove it, see #16986 (comment)

val keyInputDataType = inputDataType(keyInputData)
val valueInputDataType = inputDataType(valueInputData)

def lengthAndLoopVar(inputDataType: DataType, genInputData: ExprCode,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can simplify this method to only handle ArrayType

val builderClass = classOf[java.util.Map[_, _]].getName
// Check for constructor with capacity specification
if (Try(cls.getConstructor(Integer.TYPE)).isSuccess) {
s"$builderClass $builderValue = new ${cls.getName}($dataLength);"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use the customer map type as the type declaration, then we don't need the cast in https://github.com/apache/spark/pull/16986/files#diff-e436c96ea839dfe446837ab2a3531f93R901

$tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
}

$builderValue.$$plus$$eq($tupleLoopValue);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is ok, but it will be great if there is a way to avoid creating the tuple every time.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately, MapBuilder only accepts Tuples

Use customer map type for Java Map deserialization
Replace key/value input data with map input data
Remove null check for keys
Remove unneeded code for sequences
@michalsenkyr
Copy link
Contributor Author

So I tried to simplify the code as much as possible, removing unneeded parameters. I must admit I am not entirely sure about whether I am handling all the data types correctly but everything seems to work.

$builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
$builderValue.sizeHint($dataLength);
"""
// Java Map, AbstractMap => HashMap
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this PR focus on scala map, we don't even have a test for java map. Let's remove these and add them back in the follow-up PR.

}

val (appendToBuilder, getBuilderResult) =
if (classOf[scala.collection.Map[_, _]].isAssignableFrom(collClass)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, let's focus on scala map for now.

implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()

// Maps
/** @since 2.2.0 */
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it needs to be 2.3 now.

val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
val tupleLoopVar = s"CollectObjectsToMap_tupleLoopValue$id"
val builderValue = s"CollectObjectsToMap_builderValue$id"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We generate name for keyLoopVar and valueLoopVar here because they are used in the keyFunction and valueFunction. The tupleLoopVar and builderValue don't have this problem and we can generate them in class CollectObjectsToMap

private val curId = new java.util.concurrent.atomic.AtomicInteger()

/**
* Construct an instance of CollectObjects case class.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CollectObjects -> CollectObjectsToMap

@cloud-fan
Copy link
Contributor

ping @michalsenkyr

Bump version in scaladoc
Minor alterations based on code review
@cloud-fan
Copy link
Contributor

ok to test

}

/**
* An equivalent to the [[MapObjects]] case class but returning an ObjectType containing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's update the class doc to explicitly say that this expression is used to convert a catalyst map to external map.


// The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
// When we want to apply MapObjects on it, we have to use it.
def inputDataType(dataType: DataType) = dataType match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code in MapObejcts is:

    val inputDataType = inputData.dataType match {
      case p: PythonUserDefinedType => p.sqlType
      case _ => inputData.dataType
    }

We should call this before we do val mapType = inputData.dataType.asInstanceOf[MapType]

case _ => dataType
}

def lengthAndLoopVar(elementType: DataType, genInputData: ExprCode, method: String,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's just 2 lines method, can we inline it?

val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)

val valueLoopNullCheck =
s"$valueLoopIsNull = ${genInputData.value}.valueArray().isNullAt($loopIndex);"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about $valueArray.isNullAt($loopIndex)?

${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};

if (!${genInputData.isNull}) {
if ($getKeyLength != $getValueLength) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we don't need a keyLength and valueLength, just have a mapLength which can be calculated by MapData.numElements

checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
}

ignore("SPARK-19104: map and product combinations") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ignore?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these tests for issue SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0 as I thought I could fix it as part of this PR. However, I found out that it was a more complicated issue than I anticipated so I left the tests there and ignored them. I can remove them.

checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some nested map cases? e.g. Map(1 -> LHMap(2 -> 3))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added as a separate test case (same as sequences)

Use size of map instead of length of key/value arrays
Add Python UDT resolution to map type
Add nested map tests
Update scaladoc
@SparkQA
Copy link

SparkQA commented Jun 10, 2017

Test build #77880 has finished for PR 16986 at commit dbdcb9c.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Jun 11, 2017

Test build #77882 has finished for PR 16986 at commit e37e0ca.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

Copy link
Contributor

@cloud-fan cloud-fan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, you can address these style comments in your next bug fix PR, thanks for your contribution!

s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
val getValueArray =
s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ArrayData is automatically imported by the codegen framework, so we can just write ArrayData here

val tupleLoopValue = ctx.freshName("tupleLoopValue")
val builderValue = ctx.freshName("builderValue")

val getLength = s"${genInputData.value}.numElements()"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can inline this getLength

int $dataLength = $getLength;
$constructBuilder
$getKeyArray
$getValueArray
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can also inline getKeyArray and getValueArray

while ($loopIndex < $dataLength) {
$keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
$valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
$valueLoopNullCheck
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can also inline this. The principle is, we should inline these simple codes as many as possible, then when you look at this code block, it's more clear what's going on.

ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}

test("arbitrary maps") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this suite is DatasetPrimitiveSuite, we should move the list/seq/map tests to a new suite DatasetComplexTypeSuite

@cloud-fan
Copy link
Contributor

merging to master!

@asfgit asfgit closed this in 0538f3b Jun 12, 2017
dataknocker pushed a commit to dataknocker/spark that referenced this pull request Jun 16, 2017
## What changes were proposed in this pull request?

Add support for arbitrary Scala `Map` types in deserialization as well as a generic implicit encoder.

Used the builder approach as in apache#16541 to construct any provided `Map` type upon deserialization.

Please note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it.

Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes:

* `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap`
* `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap`
* `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap`
* `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap`

Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private scala.collection.immutable.Map mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */
/* 112 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 113 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 114 */
/* 115 */         int deserializetoobject_loopIndex = 0;
/* 116 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 117 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 118 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 119 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 120 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 121 */
/* 122 */           if (CollectObjectsToMap_loopIsNull1) {
/* 123 */             throw new RuntimeException("Found null in map key!");
/* 124 */           }
/* 125 */
/* 126 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 127 */
/* 128 */           if (CollectObjectsToMap_loopIsNull3) {
/* 129 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 130 */           } else {
/* 131 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 132 */           }
/* 133 */
/* 134 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 135 */
/* 136 */           deserializetoobject_loopIndex += 1;
/* 137 */         }
/* 138 */
/* 139 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 140 */       }
/* 141 */
/* 142 */       boolean mapelements_isNull = true;
/* 143 */       scala.collection.immutable.Map mapelements_value = null;
/* 144 */       if (!false) {
/* 145 */         mapelements_argValue = deserializetoobject_value;
/* 146 */
/* 147 */         mapelements_isNull = false;
/* 148 */         if (!mapelements_isNull) {
/* 149 */           Object mapelements_funcResult = null;
/* 150 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 151 */           if (mapelements_funcResult == null) {
/* 152 */             mapelements_isNull = true;
/* 153 */           } else {
/* 154 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 155 */           }
/* 156 */
/* 157 */         }
/* 158 */         mapelements_isNull = mapelements_value == null;
/* 159 */       }
/* 160 */
/* 161 */       MapData serializefromobject_value = null;
/* 162 */       if (!mapelements_isNull) {
/* 163 */         final int serializefromobject_length = mapelements_value.size();
/* 164 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 165 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 166 */         int serializefromobject_index = 0;
/* 167 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 168 */         while(serializefromobject_entries.hasNext()) {
/* 169 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 170 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 171 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 172 */
/* 173 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 174 */
/* 175 */           if (false) {
/* 176 */             throw new RuntimeException("Cannot use null as map key!");
/* 177 */           } else {
/* 178 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 179 */           }
/* 180 */
/* 181 */           if (false) {
/* 182 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 183 */           } else {
/* 184 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 185 */           }
/* 186 */
/* 187 */           serializefromobject_index++;
/* 188 */         }
/* 189 */
/* 190 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 191 */       }
/* 192 */       serializefromobject_holder.reset();
/* 193 */
/* 194 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 195 */
/* 196 */       if (mapelements_isNull) {
/* 197 */         serializefromobject_rowWriter.setNullAt(0);
/* 198 */       } else {
/* 199 */         // Remember the current cursor so that we can calculate how many bytes are
/* 200 */         // written later.
/* 201 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 202 */
/* 203 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 204 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 205 */           // grow the global buffer before writing data.
/* 206 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 207 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 208 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 209 */
/* 210 */         } else {
/* 211 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 212 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 213 */
/* 214 */           // preserve 8 bytes to write the key array numBytes later.
/* 215 */           serializefromobject_holder.grow(8);
/* 216 */           serializefromobject_holder.cursor += 8;
/* 217 */
/* 218 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 219 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 220 */
/* 221 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 222 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 223 */             // grow the global buffer before writing data.
/* 224 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 225 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 226 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 227 */
/* 228 */           } else {
/* 229 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 230 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 231 */
/* 232 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 233 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 234 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 235 */               } else {
/* 236 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 237 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 238 */               }
/* 239 */             }
/* 240 */           }
/* 241 */
/* 242 */           // Write the numBytes of key array into the first 8 bytes.
/* 243 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 244 */
/* 245 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 246 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 247 */             // grow the global buffer before writing data.
/* 248 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 249 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 250 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 251 */
/* 252 */           } else {
/* 253 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 254 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 255 */
/* 256 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 257 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 258 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 259 */               } else {
/* 260 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 261 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 262 */               }
/* 263 */             }
/* 264 */           }
/* 265 */
/* 266 */         }
/* 267 */
/* 268 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 269 */       }
/* 270 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 271 */       append(serializefromobject_result);
/* 272 */       if (shouldStop()) return;
/* 273 */     }
/* 274 */   }
/* 275 */ }
```

Codegen for `java.util.Map`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private java.util.HashMap mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       java.util.HashMap deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */         (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */         java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength);
/* 112 */
/* 113 */         int deserializetoobject_loopIndex = 0;
/* 114 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 115 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 116 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 117 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 118 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 119 */
/* 120 */           if (CollectObjectsToMap_loopIsNull1) {
/* 121 */             throw new RuntimeException("Found null in map key!");
/* 122 */           }
/* 123 */
/* 124 */           CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 125 */
/* 126 */           deserializetoobject_loopIndex += 1;
/* 127 */         }
/* 128 */
/* 129 */         deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5;
/* 130 */       }
/* 131 */
/* 132 */       boolean mapelements_isNull = true;
/* 133 */       java.util.HashMap mapelements_value = null;
/* 134 */       if (!false) {
/* 135 */         mapelements_argValue = deserializetoobject_value;
/* 136 */
/* 137 */         mapelements_isNull = false;
/* 138 */         if (!mapelements_isNull) {
/* 139 */           Object mapelements_funcResult = null;
/* 140 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 141 */           if (mapelements_funcResult == null) {
/* 142 */             mapelements_isNull = true;
/* 143 */           } else {
/* 144 */             mapelements_value = (java.util.HashMap) mapelements_funcResult;
/* 145 */           }
/* 146 */
/* 147 */         }
/* 148 */         mapelements_isNull = mapelements_value == null;
/* 149 */       }
/* 150 */
/* 151 */       MapData serializefromobject_value = null;
/* 152 */       if (!mapelements_isNull) {
/* 153 */         final int serializefromobject_length = mapelements_value.size();
/* 154 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 155 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 156 */         int serializefromobject_index = 0;
/* 157 */         final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator();
/* 158 */         while(serializefromobject_entries.hasNext()) {
/* 159 */           final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next();
/* 160 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey();
/* 161 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue();
/* 162 */
/* 163 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 164 */
/* 165 */           if (false) {
/* 166 */             throw new RuntimeException("Cannot use null as map key!");
/* 167 */           } else {
/* 168 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 169 */           }
/* 170 */
/* 171 */           if (false) {
/* 172 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 173 */           } else {
/* 174 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 175 */           }
/* 176 */
/* 177 */           serializefromobject_index++;
/* 178 */         }
/* 179 */
/* 180 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 181 */       }
/* 182 */       serializefromobject_holder.reset();
/* 183 */
/* 184 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 185 */
/* 186 */       if (mapelements_isNull) {
/* 187 */         serializefromobject_rowWriter.setNullAt(0);
/* 188 */       } else {
/* 189 */         // Remember the current cursor so that we can calculate how many bytes are
/* 190 */         // written later.
/* 191 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 192 */
/* 193 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 194 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 195 */           // grow the global buffer before writing data.
/* 196 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 197 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 198 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 199 */
/* 200 */         } else {
/* 201 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 202 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 203 */
/* 204 */           // preserve 8 bytes to write the key array numBytes later.
/* 205 */           serializefromobject_holder.grow(8);
/* 206 */           serializefromobject_holder.cursor += 8;
/* 207 */
/* 208 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 209 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 210 */
/* 211 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 212 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 213 */             // grow the global buffer before writing data.
/* 214 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 215 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 216 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 217 */
/* 218 */           } else {
/* 219 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 220 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 221 */
/* 222 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 223 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 224 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 225 */               } else {
/* 226 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 227 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 228 */               }
/* 229 */             }
/* 230 */           }
/* 231 */
/* 232 */           // Write the numBytes of key array into the first 8 bytes.
/* 233 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 234 */
/* 235 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 236 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 237 */             // grow the global buffer before writing data.
/* 238 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 239 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 240 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 241 */
/* 242 */           } else {
/* 243 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 244 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 245 */
/* 246 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 247 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 248 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 249 */               } else {
/* 250 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 251 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 252 */               }
/* 253 */             }
/* 254 */           }
/* 255 */
/* 256 */         }
/* 257 */
/* 258 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 259 */       }
/* 260 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 261 */       append(serializefromobject_result);
/* 262 */       if (shouldStop()) return;
/* 263 */     }
/* 264 */   }
/* 265 */ }
```

## How was this patch tested?

```
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally in Spark shell:

```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```

Author: Michal Senkyr <[email protected]>
Author: Michal Šenkýř <[email protected]>

Closes apache#16986 from michalsenkyr/dataset-map-builder.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants