diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 7f8d6c58aec7e..74843806b3ea0 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -177,7 +177,7 @@ public void writeAll(List values) throws Exception { // Deserialize outside synchronized block List list = new ArrayList<>(entry.getValue().size()); - for (Object value : values) { + for (Object value : entry.getValue()) { list.add(serializer.serialize(value)); } serializedValueIter = list.iterator(); @@ -191,6 +191,7 @@ public void writeAll(List values) throws Exception { try (WriteBatch batch = db().createWriteBatch()) { while (valueIter.hasNext()) { + assert serializedValueIter.hasNext(); updateBatch(batch, valueIter.next(), serializedValueIter.next(), klass, naturalIndex, indices); } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java index 4bc2b233fe12d..8c9ac5a232001 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDB.java @@ -209,7 +209,7 @@ public void writeAll(List values) throws Exception { // Deserialize outside synchronized block List list = new ArrayList<>(entry.getValue().size()); - for (Object value : values) { + for (Object value : entry.getValue()) { list.add(serializer.serialize(value)); } serializedValueIter = list.iterator(); @@ -223,6 +223,7 @@ public void writeAll(List values) throws Exception { try (WriteBatch writeBatch = new WriteBatch()) { while (valueIter.hasNext()) { + assert serializedValueIter.hasNext(); updateBatch(writeBatch, valueIter.next(), serializedValueIter.next(), klass, naturalIndex, indices); } diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java index c22aea821af35..040ccce70b5a1 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java @@ -20,6 +20,7 @@ import java.io.File; import java.lang.ref.Reference; import java.lang.ref.WeakReference; +import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -422,6 +423,37 @@ public void testResourceCleaner() throws Exception { } } + @Test + public void testMultipleTypesWriteAll() throws Exception { + + List type1List = Arrays.asList( + createCustomType1(1), + createCustomType1(2), + createCustomType1(3), + createCustomType1(4) + ); + + List type2List = Arrays.asList( + createCustomType2(10), + createCustomType2(11), + createCustomType2(12), + createCustomType2(13) + ); + + List fullList = new ArrayList(); + fullList.addAll(type1List); + fullList.addAll(type2List); + + db.writeAll(fullList); + for (CustomType1 value : type1List) { + assertEquals(value, db.read(value.getClass(), value.key)); + } + for (CustomType2 value : type2List) { + assertEquals(value, db.read(value.getClass(), value.key)); + } + } + + private CustomType1 createCustomType1(int i) { CustomType1 t = new CustomType1(); t.key = "key" + i; @@ -432,6 +464,14 @@ private CustomType1 createCustomType1(int i) { return t; } + private CustomType2 createCustomType2(int i) { + CustomType2 t = new CustomType2(); + t.key = "key" + i; + t.id = "id" + i; + t.parentId = "parent_id" + (i / 2); + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0; diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java index 61f18a9a26de7..34a12d8fddec8 100644 --- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java @@ -20,6 +20,7 @@ import java.io.File; import java.lang.ref.Reference; import java.lang.ref.WeakReference; +import java.util.ArrayList; import java.util.Arrays; import java.util.Iterator; import java.util.List; @@ -420,6 +421,36 @@ public void testResourceCleaner() throws Exception { } } + @Test + public void testMultipleTypesWriteAll() throws Exception { + + List type1List = Arrays.asList( + createCustomType1(1), + createCustomType1(2), + createCustomType1(3), + createCustomType1(4) + ); + + List type2List = Arrays.asList( + createCustomType2(10), + createCustomType2(11), + createCustomType2(12), + createCustomType2(13) + ); + + List fullList = new ArrayList(); + fullList.addAll(type1List); + fullList.addAll(type2List); + + db.writeAll(fullList); + for (CustomType1 value : type1List) { + assertEquals(value, db.read(value.getClass(), value.key)); + } + for (CustomType2 value : type2List) { + assertEquals(value, db.read(value.getClass(), value.key)); + } + } + private CustomType1 createCustomType1(int i) { CustomType1 t = new CustomType1(); t.key = "key" + i; @@ -430,6 +461,14 @@ private CustomType1 createCustomType1(int i) { return t; } + private CustomType2 createCustomType2(int i) { + CustomType2 t = new CustomType2(); + t.key = "key" + i; + t.id = "id" + i; + t.parentId = "parent_id" + (i / 2); + return t; + } + private int countKeys(Class type) throws Exception { byte[] prefix = db.getTypeInfo(type).keyPrefix(); int count = 0;