Commit d262b77

michalsenkyr authored and wangzejie committed
[SPARK-18891][SQL] Support for Scala Map collection types
## What changes were proposed in this pull request?

Add support for arbitrary Scala `Map` types in deserialization as well as a generic implicit encoder.

Used the builder approach as in apache#16541 to construct any provided `Map` type upon deserialization.

Please note that this PR also adds (ignored) tests for issue [SPARK-19104 CompileException with Map and Case Class in Spark 2.1.0](https://issues.apache.org/jira/browse/SPARK-19104) but doesn't solve it.

Added support for Java Maps in codegen code (encoders will be added in a different PR) with the following default implementations for interfaces/abstract classes:

* `java.util.Map`, `java.util.AbstractMap` => `java.util.HashMap`
* `java.util.SortedMap`, `java.util.NavigableMap` => `java.util.TreeMap`
* `java.util.concurrent.ConcurrentMap` => `java.util.concurrent.ConcurrentHashMap`
* `java.util.concurrent.ConcurrentNavigableMap` => `java.util.concurrent.ConcurrentSkipListMap`

Resulting codegen for `Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private scala.collection.immutable.Map mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       scala.collection.immutable.Map deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */           (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */
/* 112 */         scala.collection.mutable.Builder CollectObjectsToMap_builderValue5 = scala.collection.immutable.Map$.MODULE$.newBuilder();
/* 113 */         CollectObjectsToMap_builderValue5.sizeHint(deserializetoobject_dataLength);
/* 114 */
/* 115 */         int deserializetoobject_loopIndex = 0;
/* 116 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 117 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 118 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 119 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 120 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 121 */
/* 122 */           if (CollectObjectsToMap_loopIsNull1) {
/* 123 */             throw new RuntimeException("Found null in map key!");
/* 124 */           }
/* 125 */
/* 126 */           scala.Tuple2 CollectObjectsToMap_loopValue4;
/* 127 */
/* 128 */           if (CollectObjectsToMap_loopIsNull3) {
/* 129 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, null);
/* 130 */           } else {
/* 131 */             CollectObjectsToMap_loopValue4 = new scala.Tuple2(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 132 */           }
/* 133 */
/* 134 */           CollectObjectsToMap_builderValue5.$plus$eq(CollectObjectsToMap_loopValue4);
/* 135 */
/* 136 */           deserializetoobject_loopIndex += 1;
/* 137 */         }
/* 138 */
/* 139 */         deserializetoobject_value = (scala.collection.immutable.Map) CollectObjectsToMap_builderValue5.result();
/* 140 */       }
/* 141 */
/* 142 */       boolean mapelements_isNull = true;
/* 143 */       scala.collection.immutable.Map mapelements_value = null;
/* 144 */       if (!false) {
/* 145 */         mapelements_argValue = deserializetoobject_value;
/* 146 */
/* 147 */         mapelements_isNull = false;
/* 148 */         if (!mapelements_isNull) {
/* 149 */           Object mapelements_funcResult = null;
/* 150 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 151 */           if (mapelements_funcResult == null) {
/* 152 */             mapelements_isNull = true;
/* 153 */           } else {
/* 154 */             mapelements_value = (scala.collection.immutable.Map) mapelements_funcResult;
/* 155 */           }
/* 156 */
/* 157 */         }
/* 158 */         mapelements_isNull = mapelements_value == null;
/* 159 */       }
/* 160 */
/* 161 */       MapData serializefromobject_value = null;
/* 162 */       if (!mapelements_isNull) {
/* 163 */         final int serializefromobject_length = mapelements_value.size();
/* 164 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 165 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 166 */         int serializefromobject_index = 0;
/* 167 */         final scala.collection.Iterator serializefromobject_entries = mapelements_value.iterator();
/* 168 */         while(serializefromobject_entries.hasNext()) {
/* 169 */           final scala.Tuple2 serializefromobject_entry = (scala.Tuple2) serializefromobject_entries.next();
/* 170 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry._1();
/* 171 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry._2();
/* 172 */
/* 173 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 174 */
/* 175 */           if (false) {
/* 176 */             throw new RuntimeException("Cannot use null as map key!");
/* 177 */           } else {
/* 178 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 179 */           }
/* 180 */
/* 181 */           if (false) {
/* 182 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 183 */           } else {
/* 184 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 185 */           }
/* 186 */
/* 187 */           serializefromobject_index++;
/* 188 */         }
/* 189 */
/* 190 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 191 */       }
/* 192 */       serializefromobject_holder.reset();
/* 193 */
/* 194 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 195 */
/* 196 */       if (mapelements_isNull) {
/* 197 */         serializefromobject_rowWriter.setNullAt(0);
/* 198 */       } else {
/* 199 */         // Remember the current cursor so that we can calculate how many bytes are
/* 200 */         // written later.
/* 201 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 202 */
/* 203 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 204 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 205 */           // grow the global buffer before writing data.
/* 206 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 207 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 208 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 209 */
/* 210 */         } else {
/* 211 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 212 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 213 */
/* 214 */           // preserve 8 bytes to write the key array numBytes later.
/* 215 */           serializefromobject_holder.grow(8);
/* 216 */           serializefromobject_holder.cursor += 8;
/* 217 */
/* 218 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 219 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 220 */
/* 221 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 222 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 223 */             // grow the global buffer before writing data.
/* 224 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 225 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 226 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 227 */
/* 228 */           } else {
/* 229 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 230 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 231 */
/* 232 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 233 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 234 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 235 */               } else {
/* 236 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 237 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 238 */               }
/* 239 */             }
/* 240 */           }
/* 241 */
/* 242 */           // Write the numBytes of key array into the first 8 bytes.
/* 243 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 244 */
/* 245 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 246 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 247 */             // grow the global buffer before writing data.
/* 248 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 249 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 250 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 251 */
/* 252 */           } else {
/* 253 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 254 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 255 */
/* 256 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 257 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 258 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 259 */               } else {
/* 260 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 261 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 262 */               }
/* 263 */             }
/* 264 */           }
/* 265 */
/* 266 */         }
/* 267 */
/* 268 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 269 */       }
/* 270 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 271 */       append(serializefromobject_result);
/* 272 */       if (shouldStop()) return;
/* 273 */     }
/* 274 */   }
/* 275 */ }
```

Codegen for `java.util.Map`:

```
/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator inputadapter_input;
/* 009 */   private boolean CollectObjectsToMap_loopIsNull1;
/* 010 */   private int CollectObjectsToMap_loopValue0;
/* 011 */   private boolean CollectObjectsToMap_loopIsNull3;
/* 012 */   private int CollectObjectsToMap_loopValue2;
/* 013 */   private UnsafeRow deserializetoobject_result;
/* 014 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder deserializetoobject_holder;
/* 015 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter deserializetoobject_rowWriter;
/* 016 */   private java.util.HashMap mapelements_argValue;
/* 017 */   private UnsafeRow mapelements_result;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder mapelements_holder;
/* 019 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter mapelements_rowWriter;
/* 020 */   private UnsafeRow serializefromobject_result;
/* 021 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder serializefromobject_holder;
/* 022 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter serializefromobject_rowWriter;
/* 023 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter;
/* 024 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter serializefromobject_arrayWriter1;
/* 025 */
/* 026 */   public GeneratedIterator(Object[] references) {
/* 027 */     this.references = references;
/* 028 */   }
/* 029 */
/* 030 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 031 */     partitionIndex = index;
/* 032 */     this.inputs = inputs;
/* 033 */     wholestagecodegen_init_0();
/* 034 */     wholestagecodegen_init_1();
/* 035 */
/* 036 */   }
/* 037 */
/* 038 */   private void wholestagecodegen_init_0() {
/* 039 */     inputadapter_input = inputs[0];
/* 040 */
/* 041 */     deserializetoobject_result = new UnsafeRow(1);
/* 042 */     this.deserializetoobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(deserializetoobject_result, 32);
/* 043 */     this.deserializetoobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(deserializetoobject_holder, 1);
/* 044 */
/* 045 */     mapelements_result = new UnsafeRow(1);
/* 046 */     this.mapelements_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(mapelements_result, 32);
/* 047 */     this.mapelements_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(mapelements_holder, 1);
/* 048 */     serializefromobject_result = new UnsafeRow(1);
/* 049 */     this.serializefromobject_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(serializefromobject_result, 32);
/* 050 */     this.serializefromobject_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(serializefromobject_holder, 1);
/* 051 */     this.serializefromobject_arrayWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 052 */
/* 053 */   }
/* 054 */
/* 055 */   private void wholestagecodegen_init_1() {
/* 056 */     this.serializefromobject_arrayWriter1 = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeArrayWriter();
/* 057 */
/* 058 */   }
/* 059 */
/* 060 */   protected void processNext() throws java.io.IOException {
/* 061 */     while (inputadapter_input.hasNext() && !stopEarly()) {
/* 062 */       InternalRow inputadapter_row = (InternalRow) inputadapter_input.next();
/* 063 */       boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 064 */       MapData inputadapter_value = inputadapter_isNull ? null : (inputadapter_row.getMap(0));
/* 065 */
/* 066 */       boolean deserializetoobject_isNull1 = true;
/* 067 */       ArrayData deserializetoobject_value1 = null;
/* 068 */       if (!inputadapter_isNull) {
/* 069 */         deserializetoobject_isNull1 = false;
/* 070 */         if (!deserializetoobject_isNull1) {
/* 071 */           Object deserializetoobject_funcResult = null;
/* 072 */           deserializetoobject_funcResult = inputadapter_value.keyArray();
/* 073 */           if (deserializetoobject_funcResult == null) {
/* 074 */             deserializetoobject_isNull1 = true;
/* 075 */           } else {
/* 076 */             deserializetoobject_value1 = (ArrayData) deserializetoobject_funcResult;
/* 077 */           }
/* 078 */
/* 079 */         }
/* 080 */         deserializetoobject_isNull1 = deserializetoobject_value1 == null;
/* 081 */       }
/* 082 */
/* 083 */       boolean deserializetoobject_isNull3 = true;
/* 084 */       ArrayData deserializetoobject_value3 = null;
/* 085 */       if (!inputadapter_isNull) {
/* 086 */         deserializetoobject_isNull3 = false;
/* 087 */         if (!deserializetoobject_isNull3) {
/* 088 */           Object deserializetoobject_funcResult1 = null;
/* 089 */           deserializetoobject_funcResult1 = inputadapter_value.valueArray();
/* 090 */           if (deserializetoobject_funcResult1 == null) {
/* 091 */             deserializetoobject_isNull3 = true;
/* 092 */           } else {
/* 093 */             deserializetoobject_value3 = (ArrayData) deserializetoobject_funcResult1;
/* 094 */           }
/* 095 */
/* 096 */         }
/* 097 */         deserializetoobject_isNull3 = deserializetoobject_value3 == null;
/* 098 */       }
/* 099 */       java.util.HashMap deserializetoobject_value = null;
/* 100 */
/* 101 */       if ((deserializetoobject_isNull1 && !deserializetoobject_isNull3) ||
/* 102 */           (!deserializetoobject_isNull1 && deserializetoobject_isNull3)) {
/* 103 */         throw new RuntimeException("Invalid state: Inconsistent nullability of key-value");
/* 104 */       }
/* 105 */
/* 106 */       if (!deserializetoobject_isNull1) {
/* 107 */         if (deserializetoobject_value1.numElements() != deserializetoobject_value3.numElements()) {
/* 108 */           throw new RuntimeException("Invalid state: Inconsistent lengths of key-value arrays");
/* 109 */         }
/* 110 */         int deserializetoobject_dataLength = deserializetoobject_value1.numElements();
/* 111 */         java.util.Map CollectObjectsToMap_builderValue5 = new java.util.HashMap(deserializetoobject_dataLength);
/* 112 */
/* 113 */         int deserializetoobject_loopIndex = 0;
/* 114 */         while (deserializetoobject_loopIndex < deserializetoobject_dataLength) {
/* 115 */           CollectObjectsToMap_loopValue0 = (int) (deserializetoobject_value1.getInt(deserializetoobject_loopIndex));
/* 116 */           CollectObjectsToMap_loopValue2 = (int) (deserializetoobject_value3.getInt(deserializetoobject_loopIndex));
/* 117 */           CollectObjectsToMap_loopIsNull1 = deserializetoobject_value1.isNullAt(deserializetoobject_loopIndex);
/* 118 */           CollectObjectsToMap_loopIsNull3 = deserializetoobject_value3.isNullAt(deserializetoobject_loopIndex);
/* 119 */
/* 120 */           if (CollectObjectsToMap_loopIsNull1) {
/* 121 */             throw new RuntimeException("Found null in map key!");
/* 122 */           }
/* 123 */
/* 124 */           CollectObjectsToMap_builderValue5.put(CollectObjectsToMap_loopValue0, CollectObjectsToMap_loopValue2);
/* 125 */
/* 126 */           deserializetoobject_loopIndex += 1;
/* 127 */         }
/* 128 */
/* 129 */         deserializetoobject_value = (java.util.HashMap) CollectObjectsToMap_builderValue5;
/* 130 */       }
/* 131 */
/* 132 */       boolean mapelements_isNull = true;
/* 133 */       java.util.HashMap mapelements_value = null;
/* 134 */       if (!false) {
/* 135 */         mapelements_argValue = deserializetoobject_value;
/* 136 */
/* 137 */         mapelements_isNull = false;
/* 138 */         if (!mapelements_isNull) {
/* 139 */           Object mapelements_funcResult = null;
/* 140 */           mapelements_funcResult = ((scala.Function1) references[0]).apply(mapelements_argValue);
/* 141 */           if (mapelements_funcResult == null) {
/* 142 */             mapelements_isNull = true;
/* 143 */           } else {
/* 144 */             mapelements_value = (java.util.HashMap) mapelements_funcResult;
/* 145 */           }
/* 146 */
/* 147 */         }
/* 148 */         mapelements_isNull = mapelements_value == null;
/* 149 */       }
/* 150 */
/* 151 */       MapData serializefromobject_value = null;
/* 152 */       if (!mapelements_isNull) {
/* 153 */         final int serializefromobject_length = mapelements_value.size();
/* 154 */         final Object[] serializefromobject_convertedKeys = new Object[serializefromobject_length];
/* 155 */         final Object[] serializefromobject_convertedValues = new Object[serializefromobject_length];
/* 156 */         int serializefromobject_index = 0;
/* 157 */         final java.util.Iterator serializefromobject_entries = mapelements_value.entrySet().iterator();
/* 158 */         while(serializefromobject_entries.hasNext()) {
/* 159 */           final java.util.Map$Entry serializefromobject_entry = (java.util.Map$Entry) serializefromobject_entries.next();
/* 160 */           int ExternalMapToCatalyst_key1 = (Integer) serializefromobject_entry.getKey();
/* 161 */           int ExternalMapToCatalyst_value1 = (Integer) serializefromobject_entry.getValue();
/* 162 */
/* 163 */           boolean ExternalMapToCatalyst_value_isNull1 = false;
/* 164 */
/* 165 */           if (false) {
/* 166 */             throw new RuntimeException("Cannot use null as map key!");
/* 167 */           } else {
/* 168 */             serializefromobject_convertedKeys[serializefromobject_index] = (Integer) ExternalMapToCatalyst_key1;
/* 169 */           }
/* 170 */
/* 171 */           if (false) {
/* 172 */             serializefromobject_convertedValues[serializefromobject_index] = null;
/* 173 */           } else {
/* 174 */             serializefromobject_convertedValues[serializefromobject_index] = (Integer) ExternalMapToCatalyst_value1;
/* 175 */           }
/* 176 */
/* 177 */           serializefromobject_index++;
/* 178 */         }
/* 179 */
/* 180 */         serializefromobject_value = new org.apache.spark.sql.catalyst.util.ArrayBasedMapData(new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedKeys), new org.apache.spark.sql.catalyst.util.GenericArrayData(serializefromobject_convertedValues));
/* 181 */       }
/* 182 */       serializefromobject_holder.reset();
/* 183 */
/* 184 */       serializefromobject_rowWriter.zeroOutNullBytes();
/* 185 */
/* 186 */       if (mapelements_isNull) {
/* 187 */         serializefromobject_rowWriter.setNullAt(0);
/* 188 */       } else {
/* 189 */         // Remember the current cursor so that we can calculate how many bytes are
/* 190 */         // written later.
/* 191 */         final int serializefromobject_tmpCursor = serializefromobject_holder.cursor;
/* 192 */
/* 193 */         if (serializefromobject_value instanceof UnsafeMapData) {
/* 194 */           final int serializefromobject_sizeInBytes = ((UnsafeMapData) serializefromobject_value).getSizeInBytes();
/* 195 */           // grow the global buffer before writing data.
/* 196 */           serializefromobject_holder.grow(serializefromobject_sizeInBytes);
/* 197 */           ((UnsafeMapData) serializefromobject_value).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 198 */           serializefromobject_holder.cursor += serializefromobject_sizeInBytes;
/* 199 */
/* 200 */         } else {
/* 201 */           final ArrayData serializefromobject_keys = serializefromobject_value.keyArray();
/* 202 */           final ArrayData serializefromobject_values = serializefromobject_value.valueArray();
/* 203 */
/* 204 */           // preserve 8 bytes to write the key array numBytes later.
/* 205 */           serializefromobject_holder.grow(8);
/* 206 */           serializefromobject_holder.cursor += 8;
/* 207 */
/* 208 */           // Remember the current cursor so that we can write numBytes of key array later.
/* 209 */           final int serializefromobject_tmpCursor1 = serializefromobject_holder.cursor;
/* 210 */
/* 211 */           if (serializefromobject_keys instanceof UnsafeArrayData) {
/* 212 */             final int serializefromobject_sizeInBytes1 = ((UnsafeArrayData) serializefromobject_keys).getSizeInBytes();
/* 213 */             // grow the global buffer before writing data.
/* 214 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes1);
/* 215 */             ((UnsafeArrayData) serializefromobject_keys).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 216 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes1;
/* 217 */
/* 218 */           } else {
/* 219 */             final int serializefromobject_numElements = serializefromobject_keys.numElements();
/* 220 */             serializefromobject_arrayWriter.initialize(serializefromobject_holder, serializefromobject_numElements, 4);
/* 221 */
/* 222 */             for (int serializefromobject_index1 = 0; serializefromobject_index1 < serializefromobject_numElements; serializefromobject_index1++) {
/* 223 */               if (serializefromobject_keys.isNullAt(serializefromobject_index1)) {
/* 224 */                 serializefromobject_arrayWriter.setNullInt(serializefromobject_index1);
/* 225 */               } else {
/* 226 */                 final int serializefromobject_element = serializefromobject_keys.getInt(serializefromobject_index1);
/* 227 */                 serializefromobject_arrayWriter.write(serializefromobject_index1, serializefromobject_element);
/* 228 */               }
/* 229 */             }
/* 230 */           }
/* 231 */
/* 232 */           // Write the numBytes of key array into the first 8 bytes.
/* 233 */           Platform.putLong(serializefromobject_holder.buffer, serializefromobject_tmpCursor1 - 8, serializefromobject_holder.cursor - serializefromobject_tmpCursor1);
/* 234 */
/* 235 */           if (serializefromobject_values instanceof UnsafeArrayData) {
/* 236 */             final int serializefromobject_sizeInBytes2 = ((UnsafeArrayData) serializefromobject_values).getSizeInBytes();
/* 237 */             // grow the global buffer before writing data.
/* 238 */             serializefromobject_holder.grow(serializefromobject_sizeInBytes2);
/* 239 */             ((UnsafeArrayData) serializefromobject_values).writeToMemory(serializefromobject_holder.buffer, serializefromobject_holder.cursor);
/* 240 */             serializefromobject_holder.cursor += serializefromobject_sizeInBytes2;
/* 241 */
/* 242 */           } else {
/* 243 */             final int serializefromobject_numElements1 = serializefromobject_values.numElements();
/* 244 */             serializefromobject_arrayWriter1.initialize(serializefromobject_holder, serializefromobject_numElements1, 4);
/* 245 */
/* 246 */             for (int serializefromobject_index2 = 0; serializefromobject_index2 < serializefromobject_numElements1; serializefromobject_index2++) {
/* 247 */               if (serializefromobject_values.isNullAt(serializefromobject_index2)) {
/* 248 */                 serializefromobject_arrayWriter1.setNullInt(serializefromobject_index2);
/* 249 */               } else {
/* 250 */                 final int serializefromobject_element1 = serializefromobject_values.getInt(serializefromobject_index2);
/* 251 */                 serializefromobject_arrayWriter1.write(serializefromobject_index2, serializefromobject_element1);
/* 252 */               }
/* 253 */             }
/* 254 */           }
/* 255 */
/* 256 */         }
/* 257 */
/* 258 */         serializefromobject_rowWriter.setOffsetAndSize(0, serializefromobject_tmpCursor, serializefromobject_holder.cursor - serializefromobject_tmpCursor);
/* 259 */       }
/* 260 */       serializefromobject_result.setTotalSize(serializefromobject_holder.totalSize());
/* 261 */       append(serializefromobject_result);
/* 262 */       if (shouldStop()) return;
/* 263 */     }
/* 264 */   }
/* 265 */ }
```

## How was this patch tested?

```
build/mvn -DskipTests clean package && dev/run-tests
```

Additionally in Spark shell:

```
scala> Seq(collection.mutable.HashMap(1 -> 2, 2 -> 3)).toDS().map(_ += (3 -> 4)).collect()
res0: Array[scala.collection.mutable.HashMap[Int,Int]] = Array(Map(2 -> 3, 1 -> 2, 3 -> 4))
```

Author: Michal Senkyr <[email protected]>
Author: Michal Šenkýř <[email protected]>

Closes apache#16986 from michalsenkyr/dataset-map-builder.
1 parent 9713c7c commit d262b77
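The codegen dumps in the commit message can be reproduced directly; a minimal `spark-shell` sketch against a build containing this commit (a `SparkSession` named `spark` and its implicits are assumed):

```scala
import spark.implicits._

// The identity map() forces a full deserialize-to-object /
// serialize-from-object round trip, which is what exercises the new
// CollectObjectsToMap path shown in the dumps above.
Seq(Map(1 -> 2)).toDS().map(identity).queryExecution.debug.codegen
```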

File tree

5 files changed: +291 -27 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 7 additions & 26 deletions
```diff
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 
@@ -335,31 +335,12 @@ object ScalaReflection extends ScalaReflection {
       // TODO: add walked type path for map
       val TypeRef(_, _, Seq(keyType, valueType)) = t
 
-      val keyData =
-        Invoke(
-          MapObjects(
-            p => deserializerFor(keyType, Some(p), walkedTypePath),
-            Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType),
-              returnNullable = false),
-            schemaFor(keyType).dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-      val valueData =
-        Invoke(
-          MapObjects(
-            p => deserializerFor(valueType, Some(p), walkedTypePath),
-            Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType),
-              returnNullable = false),
-            schemaFor(valueType).dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]), returnNullable = false)
-
-      StaticInvoke(
-        ArrayBasedMapData.getClass,
-        ObjectType(classOf[scala.collection.immutable.Map[_, _]]),
-        "toScalaMap",
-        keyData :: valueData :: Nil)
+      CollectObjectsToMap(
+        p => deserializerFor(keyType, Some(p), walkedTypePath),
+        p => deserializerFor(valueType, Some(p), walkedTypePath),
+        getPath,
+        mirror.runtimeClass(t.typeSymbol.asClass)
+      )
 
     case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
       val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
```
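The change above swaps a hard-coded `scala.collection.immutable.Map` construction (via `ArrayBasedMapData.toScalaMap`) for `CollectObjectsToMap`, which targets the runtime class of the requested `Map` type. A plain-Scala analogue of the two strategies, for illustration only (`deserializeKey`/`deserializeValue` are hypothetical stand-ins for the per-element deserializer expressions):

```scala
import scala.collection.mutable.LinkedHashMap

def deserializeKey(k: Int): Int = k     // stand-in for the key deserializer
def deserializeValue(v: Int): Int = v   // stand-in for the value deserializer

val keys = Array(1, 2)
val values = Array(10, 20)

// Old approach: always zip deserialized keys/values into an immutable Map.
val oldStyle: Map[Int, Int] =
  keys.map(deserializeKey).zip(values.map(deserializeValue)).toMap

// New approach: drive the builder of whatever Map type was requested,
// here a mutable LinkedHashMap.
val builder = LinkedHashMap.newBuilder[Int, Int]
builder.sizeHint(keys.length)
for (i <- keys.indices) {
  builder += (deserializeKey(keys(i)) -> deserializeValue(values(i)))
}
val newStyle: LinkedHashMap[Int, Int] = builder.result()
```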

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 168 additions & 1 deletion
```diff
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -652,6 +652,173 @@ case class MapObjects private(
   }
 }
 
+object CollectObjectsToMap {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  /**
+   * Construct an instance of CollectObjectsToMap case class.
+   *
+   * @param keyFunction The function applied on the key collection elements.
+   * @param valueFunction The function applied on the value collection elements.
+   * @param inputData An expression that when evaluated returns a map object.
+   * @param collClass The type of the resulting collection.
+   */
+  def apply(
+      keyFunction: Expression => Expression,
+      valueFunction: Expression => Expression,
+      inputData: Expression,
+      collClass: Class[_]): CollectObjectsToMap = {
+    val id = curId.getAndIncrement()
+    val keyLoopValue = s"CollectObjectsToMap_keyLoopValue$id"
+    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
+    val valueLoopValue = s"CollectObjectsToMap_valueLoopValue$id"
+    val valueLoopIsNull = s"CollectObjectsToMap_valueLoopIsNull$id"
+    val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
+    CollectObjectsToMap(
+      keyLoopValue, keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
+      inputData, collClass)
+  }
+}
+
+/**
+ * Expression used to convert a Catalyst Map to an external Scala Map.
+ * The collection is constructed using the associated builder, obtained by calling `newBuilder`
+ * on the collection's companion object.
+ *
+ * @param keyLoopValue the name of the loop variable that is used when iterating over the key
+ *                     collection, and which is used as input for the `keyLambdaFunction`
+ * @param keyLambdaFunction A function that takes the `keyLoopVar` as input, and is used as
+ *                          a lambda function to handle collection elements.
+ * @param valueLoopValue the name of the loop variable that is used when iterating over the value
+ *                       collection, and which is used as input for the `valueLambdaFunction`
+ * @param valueLoopIsNull the nullability of the loop variable that is used when iterating over
+ *                        the value collection, and which is used as input for the
+ *                        `valueLambdaFunction`
+ * @param valueLambdaFunction A function that takes the `valueLoopVar` as input, and is used as
+ *                            a lambda function to handle collection elements.
+ * @param inputData An expression that when evaluated returns a map object.
+ * @param collClass The type of the resulting collection.
+ */
+case class CollectObjectsToMap private(
+    keyLoopValue: String,
+    keyLambdaFunction: Expression,
+    valueLoopValue: String,
+    valueLoopIsNull: String,
+    valueLambdaFunction: Expression,
+    inputData: Expression,
+    collClass: Class[_]) extends Expression with NonSQLExpression {
+
+  override def nullable: Boolean = inputData.nullable
+
+  override def children: Seq[Expression] =
+    keyLambdaFunction :: valueLambdaFunction :: inputData :: Nil
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override def dataType: DataType = ObjectType(collClass)
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    // The data with PythonUserDefinedType are actually stored with the data type of its sqlType.
+    // When we want to apply MapObjects on it, we have to use it.
+    def inputDataType(dataType: DataType) = dataType match {
+      case p: PythonUserDefinedType => p.sqlType
+      case _ => dataType
+    }
+
+    val mapType = inputDataType(inputData.dataType).asInstanceOf[MapType]
+    val keyElementJavaType = ctx.javaType(mapType.keyType)
+    ctx.addMutableState(keyElementJavaType, keyLoopValue, "")
+    val genKeyFunction = keyLambdaFunction.genCode(ctx)
+    val valueElementJavaType = ctx.javaType(mapType.valueType)
+    ctx.addMutableState("boolean", valueLoopIsNull, "")
+    ctx.addMutableState(valueElementJavaType, valueLoopValue, "")
+    val genValueFunction = valueLambdaFunction.genCode(ctx)
+    val genInputData = inputData.genCode(ctx)
+    val dataLength = ctx.freshName("dataLength")
+    val loopIndex = ctx.freshName("loopIndex")
+    val tupleLoopValue = ctx.freshName("tupleLoopValue")
+    val builderValue = ctx.freshName("builderValue")
+
+    val getLength = s"${genInputData.value}.numElements()"
+
+    val keyArray = ctx.freshName("keyArray")
+    val valueArray = ctx.freshName("valueArray")
+    val getKeyArray =
+      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
+    val getKeyLoopVar = ctx.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
+    val getValueArray =
+      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
+    val getValueLoopVar = ctx.getValue(valueArray, inputDataType(mapType.valueType), loopIndex)
+
+    // Make a copy of the data if it's unsafe-backed
+    def makeCopyIfInstanceOf(clazz: Class[_ <: Any], value: String) =
+      s"$value instanceof ${clazz.getSimpleName}? $value.copy() : $value"
+    def genFunctionValue(lambdaFunction: Expression, genFunction: ExprCode) =
+      lambdaFunction.dataType match {
+        case StructType(_) => makeCopyIfInstanceOf(classOf[UnsafeRow], genFunction.value)
+        case ArrayType(_, _) => makeCopyIfInstanceOf(classOf[UnsafeArrayData], genFunction.value)
+        case MapType(_, _, _) => makeCopyIfInstanceOf(classOf[UnsafeMapData], genFunction.value)
+        case _ => genFunction.value
+      }
+    val genKeyFunctionValue = genFunctionValue(keyLambdaFunction, genKeyFunction)
+    val genValueFunctionValue = genFunctionValue(valueLambdaFunction, genValueFunction)
+
+    val valueLoopNullCheck = s"$valueLoopIsNull = $valueArray.isNullAt($loopIndex);"
+
+    val builderClass = classOf[Builder[_, _]].getName
+    val constructBuilder = s"""
+      $builderClass $builderValue = ${collClass.getName}$$.MODULE$$.newBuilder();
+      $builderValue.sizeHint($dataLength);
+    """
+
+    val tupleClass = classOf[(_, _)].getName
+    val appendToBuilder = s"""
+      $tupleClass $tupleLoopValue;
+
+      if (${genValueFunction.isNull}) {
+        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, null);
+      } else {
+        $tupleLoopValue = new $tupleClass($genKeyFunctionValue, $genValueFunctionValue);
+      }
+
+      $builderValue.$$plus$$eq($tupleLoopValue);
+    """
+    val getBuilderResult = s"${ev.value} = (${collClass.getName}) $builderValue.result();"
+
+    val code = s"""
+      ${genInputData.code}
+      ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+
+      if (!${genInputData.isNull}) {
+        int $dataLength = $getLength;
+        $constructBuilder
+        $getKeyArray
+        $getValueArray
+
+        int $loopIndex = 0;
+        while ($loopIndex < $dataLength) {
+          $keyLoopValue = ($keyElementJavaType) ($getKeyLoopVar);
+          $valueLoopValue = ($valueElementJavaType) ($getValueLoopVar);
+          $valueLoopNullCheck
+
+          ${genKeyFunction.code}
+          ${genValueFunction.code}
+
+          $appendToBuilder
+
+          $loopIndex += 1;
+        }
+
+        $getBuilderResult
+      }
+    """
+    ev.copy(code = code, isNull = genInputData.isNull)
+  }
+}
+
 object ExternalMapToCatalyst {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
```
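The generated Java shown in the commit message obtains this builder reflectively as `<collClass>$.MODULE$.newBuilder()`. The same pattern in plain Scala (a sketch for illustration, not code from the patch):

```scala
import scala.collection.immutable.HashMap
import scala.collection.mutable.Builder

// Every Scala Map companion object exposes newBuilder, which is what lets
// CollectObjectsToMap construct any requested Map subtype.
val builder: Builder[(Int, String), HashMap[Int, String]] = HashMap.newBuilder[Int, String]
builder.sizeHint(2)
builder += (1 -> "a")
builder += (2 -> "b")
val map: HashMap[Int, String] = builder.result()
```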

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala

Lines changed: 25 additions & 0 deletions
```diff
@@ -314,6 +314,31 @@ class ScalaReflectionSuite extends SparkFunSuite {
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
   }
 
+  test("serialize and deserialize arbitrary map types") {
+    val mapSerializer = serializerFor[Map[Int, Int]](BoundReference(
+      0, ObjectType(classOf[Map[Int, Int]]), nullable = false))
+    assert(mapSerializer.dataType.head.dataType ==
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val mapDeserializer = deserializerFor[Map[Int, Int]]
+    assert(mapDeserializer.dataType == ObjectType(classOf[Map[_, _]]))
+
+    import scala.collection.immutable.HashMap
+    val hashMapSerializer = serializerFor[HashMap[Int, Int]](BoundReference(
+      0, ObjectType(classOf[HashMap[Int, Int]]), nullable = false))
+    assert(hashMapSerializer.dataType.head.dataType ==
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val hashMapDeserializer = deserializerFor[HashMap[Int, Int]]
+    assert(hashMapDeserializer.dataType == ObjectType(classOf[HashMap[_, _]]))
+
+    import scala.collection.mutable.{LinkedHashMap => LHMap}
+    val linkedHashMapSerializer = serializerFor[LHMap[Long, String]](BoundReference(
+      0, ObjectType(classOf[LHMap[Long, String]]), nullable = false))
+    assert(linkedHashMapSerializer.dataType.head.dataType ==
+      MapType(LongType, StringType, valueContainsNull = true))
+    val linkedHashMapDeserializer = deserializerFor[LHMap[Long, String]]
+    assert(linkedHashMapDeserializer.dataType == ObjectType(classOf[LHMap[_, _]]))
+  }
+
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
   private val typeOfComplexData = typeOf[ComplexData]
```
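The schemas these assertions rely on can also be inspected directly through `ScalaReflection.schemaFor`; a minimal sketch, assuming a build containing this patch:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types._
import scala.collection.mutable.{LinkedHashMap => LHMap}

// Primitive values cannot be null, so valueContainsNull = false.
val intMapSchema = ScalaReflection.schemaFor[Map[Int, Int]]
assert(intMapSchema.dataType == MapType(IntegerType, IntegerType, valueContainsNull = false))

// String values are nullable, so valueContainsNull = true.
val lhMapSchema = ScalaReflection.schemaFor[LHMap[Long, String]]
assert(lhMapSchema.dataType == MapType(LongType, StringType, valueContainsNull = true))
```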

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala

Lines changed: 5 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql
 
+import scala.collection.Map
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 
@@ -166,6 +167,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
   /** @since 2.2.0 */
   implicit def newSequenceEncoder[T <: Seq[_] : TypeTag]: Encoder[T] = ExpressionEncoder()
 
+  // Maps
+  /** @since 2.3.0 */
+  implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder()
+
   // Arrays
 
   /** @since 1.6.1 */
```
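With this encoder in implicit scope, any subtype of `scala.collection.Map` works with `toDS()` without an explicit encoder. A hypothetical session sketch (assumes a `SparkSession` named `spark`):

```scala
import spark.implicits._   // brings newMapEncoder into scope
import scala.collection.mutable.{LinkedHashMap => LHMap}

// Both encoders are resolved through newMapEncoder[T <: Map[_, _]].
val ds1 = Seq(Map("a" -> 1)).toDS()
val ds2 = Seq(LHMap("a" -> 1)).toDS()

// Mirrors the spark-shell check in the commit message: mutate and collect.
ds2.map(_ += ("b" -> 2)).collect()
```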

sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala

Lines changed: 86 additions & 0 deletions
```diff
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import scala.collection.immutable.Queue
+import scala.collection.mutable.{LinkedHashMap => LHMap}
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.sql.test.SharedSQLContext
@@ -30,8 +31,14 @@ case class ListClass(l: List[Int])
 
 case class QueueClass(q: Queue[Int])
 
+case class MapClass(m: Map[Int, Int])
+
+case class LHMapClass(m: LHMap[Int, Int])
+
 case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)
 
+case class ComplexMapClass(map: MapClass, lhmap: LHMapClass)
+
 package object packageobject {
   case class PackageClass(value: Int)
 }
@@ -258,11 +265,90 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
       ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
   }
 
+  test("arbitrary maps") {
+    checkDataset(Seq(Map(1 -> 2)).toDS(), Map(1 -> 2))
+    checkDataset(Seq(Map(1.toLong -> 2.toLong)).toDS(), Map(1.toLong -> 2.toLong))
+    checkDataset(Seq(Map(1.toDouble -> 2.toDouble)).toDS(), Map(1.toDouble -> 2.toDouble))
+    checkDataset(Seq(Map(1.toFloat -> 2.toFloat)).toDS(), Map(1.toFloat -> 2.toFloat))
+    checkDataset(Seq(Map(1.toByte -> 2.toByte)).toDS(), Map(1.toByte -> 2.toByte))
+    checkDataset(Seq(Map(1.toShort -> 2.toShort)).toDS(), Map(1.toShort -> 2.toShort))
+    checkDataset(Seq(Map(true -> false)).toDS(), Map(true -> false))
+    checkDataset(Seq(Map("test1" -> "test2")).toDS(), Map("test1" -> "test2"))
+    checkDataset(Seq(Map(Tuple1(1) -> Tuple1(2))).toDS(), Map(Tuple1(1) -> Tuple1(2)))
+    checkDataset(Seq(Map(1 -> Tuple1(2))).toDS(), Map(1 -> Tuple1(2)))
+    checkDataset(Seq(Map("test" -> 2.toLong)).toDS(), Map("test" -> 2.toLong))
+
+    checkDataset(Seq(LHMap(1 -> 2)).toDS(), LHMap(1 -> 2))
+    checkDataset(Seq(LHMap(1.toLong -> 2.toLong)).toDS(), LHMap(1.toLong -> 2.toLong))
+    checkDataset(Seq(LHMap(1.toDouble -> 2.toDouble)).toDS(), LHMap(1.toDouble -> 2.toDouble))
+    checkDataset(Seq(LHMap(1.toFloat -> 2.toFloat)).toDS(), LHMap(1.toFloat -> 2.toFloat))
+    checkDataset(Seq(LHMap(1.toByte -> 2.toByte)).toDS(), LHMap(1.toByte -> 2.toByte))
+    checkDataset(Seq(LHMap(1.toShort -> 2.toShort)).toDS(), LHMap(1.toShort -> 2.toShort))
+    checkDataset(Seq(LHMap(true -> false)).toDS(), LHMap(true -> false))
+    checkDataset(Seq(LHMap("test1" -> "test2")).toDS(), LHMap("test1" -> "test2"))
+    checkDataset(Seq(LHMap(Tuple1(1) -> Tuple1(2))).toDS(), LHMap(Tuple1(1) -> Tuple1(2)))
+    checkDataset(Seq(LHMap(1 -> Tuple1(2))).toDS(), LHMap(1 -> Tuple1(2)))
+    checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
+  }
+
+  ignore("SPARK-19104: map and product combinations") {
+    // Case classes
+    checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
+    checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))
+    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> 3)).toDS(), Map(MapClass(Map(1 -> 2)) -> 3))
+    checkDataset(Seq(Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+      Map(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+    checkDataset(Seq(LHMap(1 -> MapClass(Map(2 -> 3)))).toDS(), LHMap(1 -> MapClass(Map(2 -> 3))))
+    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> 3)).toDS(), LHMap(MapClass(Map(1 -> 2)) -> 3))
+    checkDataset(Seq(LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4)))).toDS(),
+      LHMap(MapClass(Map(1 -> 2)) -> MapClass(Map(3 -> 4))))
+
+    checkDataset(Seq(LHMapClass(LHMap(1 -> 2))).toDS(), LHMapClass(LHMap(1 -> 2)))
+    checkDataset(Seq(Map(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+      Map(1 -> LHMapClass(LHMap(2 -> 3))))
+    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+      Map(LHMapClass(LHMap(1 -> 2)) -> 3))
+    checkDataset(Seq(Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+      Map(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+    checkDataset(Seq(LHMap(1 -> LHMapClass(LHMap(2 -> 3)))).toDS(),
+      LHMap(1 -> LHMapClass(LHMap(2 -> 3))))
+    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> 3)).toDS(),
+      LHMap(LHMapClass(LHMap(1 -> 2)) -> 3))
+    checkDataset(Seq(LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4)))).toDS(),
+      LHMap(LHMapClass(LHMap(1 -> 2)) -> LHMapClass(LHMap(3 -> 4))))
+
+    val complex = ComplexMapClass(MapClass(Map(1 -> 2)), LHMapClass(LHMap(3 -> 4)))
+    checkDataset(Seq(complex).toDS(), complex)
+    checkDataset(Seq(Map(1 -> complex)).toDS(), Map(1 -> complex))
+    checkDataset(Seq(Map(complex -> 5)).toDS(), Map(complex -> 5))
+    checkDataset(Seq(Map(complex -> complex)).toDS(), Map(complex -> complex))
+    checkDataset(Seq(LHMap(1 -> complex)).toDS(), LHMap(1 -> complex))
+    checkDataset(Seq(LHMap(complex -> 5)).toDS(), LHMap(complex -> 5))
+    checkDataset(Seq(LHMap(complex -> complex)).toDS(), LHMap(complex -> complex))
+
+    // Tuples
+    checkDataset(Seq(Map(1 -> 2) -> Map(3 -> 4)).toDS(), Map(1 -> 2) -> Map(3 -> 4))
+    checkDataset(Seq(LHMap(1 -> 2) -> Map(3 -> 4)).toDS(), LHMap(1 -> 2) -> Map(3 -> 4))
+    checkDataset(Seq(Map(1 -> 2) -> LHMap(3 -> 4)).toDS(), Map(1 -> 2) -> LHMap(3 -> 4))
+    checkDataset(Seq(LHMap(1 -> 2) -> LHMap(3 -> 4)).toDS(), LHMap(1 -> 2) -> LHMap(3 -> 4))
+    checkDataset(Seq(LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2")))).toDS(),
+      LHMap((Map("test1" -> 1) -> 2) -> (3 -> LHMap(4 -> "test2"))))
+
+    // Complex
+    checkDataset(Seq(LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))).toDS(),
+      LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4))))
+  }
+
   test("nested sequences") {
     checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1)))
     checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1)))
   }
 
+  test("nested maps") {
+    checkDataset(Seq(Map(1 -> LHMap(2 -> 3))).toDS(), Map(1 -> LHMap(2 -> 3)))
+    checkDataset(Seq(LHMap(Map(1 -> 2) -> 3)).toDS(), LHMap(Map(1 -> 2) -> 3))
+  }
+
   test("package objects") {
     import packageobject._
     checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))
```
