
Commit 8bc66e6

Vladimir Kuriatkov authored and committed
Java array of structs deserialization fixed
1 parent: bac0ff9 · commit: 8bc66e6

File tree: 3 files changed, 226 insertions(+), 2 deletions(-)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala

Lines changed: 1 addition & 2 deletions
Lines changed: 1 addition & 2 deletions

@@ -271,10 +271,9 @@ object JavaTypeInference {
 
       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
-        MapObjects(
+        UnresolvedMapObjects(
           p => deserializerFor(et, Some(p)),
           getPath,
-          inferDataType(et)._1,
           customCollectionCls = Some(c))
 
       case _ if mapType.isAssignableFrom(typeToken) =>
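
With this change, the deserializer for a Java bean field of List type is built with UnresolvedMapObjects instead of MapObjects, so the element deserializer is resolved by the analyzer against the actual input schema rather than against the type returned by inferDataType, which mishandled list elements that are themselves beans (arrays of structs). The sketch below is not part of the commit; the class names, the local SparkSession setup, and the round-trip scenario are illustrative. It shows the shape of bean this fix lets Encoders.bean deserialize correctly.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class ArrayOfStructsExample {

  // Illustrative nested bean: becomes a struct element inside an array column.
  public static class Point {
    private long x;
    private long y;
    public Point() { }
    public Point(long x, long y) { this.x = x; this.y = y; }
    public long getX() { return x; }
    public void setX(long x) { this.x = x; }
    public long getY() { return y; }
    public void setY(long y) { this.y = y; }
  }

  // Illustrative outer bean: the List<Point> field maps to an array-of-structs column.
  public static class Track {
    private List<Point> points;
    public Track() { }
    public Track(List<Point> points) { this.points = points; }
    public List<Point> getPoints() { return points; }
    public void setPoints(List<Point> points) { this.points = points; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]")
        .appName("array-of-structs-example")
        .getOrCreate();

    List<Track> data = Arrays.asList(
        new Track(Arrays.asList(new Point(1, 2), new Point(3, 4))));

    // Serialize with the bean encoder, then deserialize back on collect();
    // deserializing the array-of-structs column is the case this commit fixes.
    Dataset<Track> ds = spark.createDataset(data, Encoders.bean(Track.class));
    List<Track> roundTripped = ds.collectAsList();

    System.out.println(roundTripped.get(0).getPoints().size());  // prints 2
    spark.stop();
  }
}

The new JavaBeanWithArraySuite below exercises the same shape of data by reading it back from JSON with an explicit schema.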
JavaBeanWithArraySuite.java (new file)

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package test.org.apache.spark.sql;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.test.TestSparkSession;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class JavaBeanWithArraySuite {

  private static final List<Record> RECORDS = new ArrayList<>();

  static {
    RECORDS.add(new Record(1,
      Arrays.asList(new Interval(111, 211), new Interval(121, 221)),
      Arrays.asList(11, 21, 31, 41)
    ));
    RECORDS.add(new Record(2,
      Arrays.asList(new Interval(112, 212), new Interval(122, 222)),
      Arrays.asList(12, 22, 32, 42)
    ));
    RECORDS.add(new Record(3,
      Arrays.asList(new Interval(113, 213), new Interval(123, 223)),
      Arrays.asList(13, 23, 33, 43)
    ));
  }

  private TestSparkSession spark;

  @Before
  public void setUp() {
    spark = new TestSparkSession();
  }

  @After
  public void tearDown() {
    spark.stop();
    spark = null;
  }

  @Test
  public void testBeanWithArrayFieldsDeserialization() {

    StructType schema = createSchema();
    Encoder<Record> encoder = Encoders.bean(Record.class);

    Dataset<Record> dataset = spark
      .read()
      .format("json")
      .schema(schema)
      .load("src/test/resources/test-data/with-array-fields")
      .as(encoder);

    List<Record> records = dataset.collectAsList();

    Assert.assertTrue(Util.equals(records, RECORDS));
  }

  private static StructType createSchema() {
    StructField[] intervalFields = {
      new StructField("startTime", DataTypes.LongType, true, Metadata.empty()),
      new StructField("endTime", DataTypes.LongType, true, Metadata.empty())
    };
    DataType intervalType = new StructType(intervalFields);

    DataType intervalsType = new ArrayType(intervalType, true);

    DataType valuesType = new ArrayType(DataTypes.IntegerType, true);

    StructField[] fields = {
      new StructField("id", DataTypes.IntegerType, true, Metadata.empty()),
      new StructField("intervals", intervalsType, true, Metadata.empty()),
      new StructField("values", valuesType, true, Metadata.empty())
    };
    return new StructType(fields);
  }

  public static class Record {

    private int id;
    private List<Interval> intervals;
    private List<Integer> values;

    public Record() { }

    Record(int id, List<Interval> intervals, List<Integer> values) {
      this.id = id;
      this.intervals = intervals;
      this.values = values;
    }

    public int getId() {
      return id;
    }

    public void setId(int id) {
      this.id = id;
    }

    public List<Interval> getIntervals() {
      return intervals;
    }

    public void setIntervals(List<Interval> intervals) {
      this.intervals = intervals;
    }

    public List<Integer> getValues() {
      return values;
    }

    public void setValues(List<Integer> values) {
      this.values = values;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Record)) return false;
      Record other = (Record) obj;
      return
        (other.id == this.id) &&
        Util.equals(other.intervals, this.intervals) &&
        Util.equals(other.values, this.values);
    }

    @Override
    public String toString() {
      return String.format("{ id: %d, intervals: %s }", id, intervals);
    }
  }

  public static class Interval {

    private long startTime;
    private long endTime;

    public Interval() { }

    Interval(long startTime, long endTime) {
      this.startTime = startTime;
      this.endTime = endTime;
    }

    public long getStartTime() {
      return startTime;
    }

    public void setStartTime(long startTime) {
      this.startTime = startTime;
    }

    public long getEndTime() {
      return endTime;
    }

    public void setEndTime(long endTime) {
      this.endTime = endTime;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof Interval)) return false;
      Interval other = (Interval) obj;
      return
        (other.startTime == this.startTime) &&
        (other.endTime == this.endTime);
    }

    @Override
    public String toString() {
      return String.format("[%d,%d]", startTime, endTime);
    }
  }

  private static class Util {

    private static <E> boolean equals(Collection<E> as, Collection<E> bs) {
      if (as.size() != bs.size()) return false;
      Iterator<E> ai = as.iterator();
      Iterator<E> bi = bs.iterator();
      while (ai.hasNext() && bi.hasNext()) {
        if (!ai.next().equals(bi.next())) {
          return false;
        }
      }
      return true;
    }
  }
}
New test-data file under test-data/with-array-fields

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
{ "id": 1, "intervals": [{ "startTime": 111, "endTime": 211 }, { "startTime": 121, "endTime": 221 }], "values": [11, 21, 31, 41]}
{ "id": 2, "intervals": [{ "startTime": 112, "endTime": 212 }, { "startTime": 122, "endTime": 222 }], "values": [12, 22, 32, 42]}
{ "id": 3, "intervals": [{ "startTime": 113, "endTime": 213 }, { "startTime": 123, "endTime": 223 }], "values": [13, 23, 33, 43]}
