diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
index 121dfbd4f6838..e21a737837465 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java
@@ -20,6 +20,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.lang.ref.SoftReference;
+import java.nio.ByteBuffer;
 import java.util.*;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedQueue;
@@ -31,6 +32,7 @@
 import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
+import com.google.common.collect.Iterables;
 import org.fusesource.leveldbjni.JniDBFactory;
 import org.iq80.leveldb.DB;
 import org.iq80.leveldb.Options;
@@ -154,7 +156,7 @@ public void write(Object value) throws Exception {
     try (WriteBatch batch = db().createWriteBatch()) {
       byte[] data = serializer.serialize(value);
       synchronized (ti) {
-        updateBatch(batch, value, data, value.getClass(), ti.naturalIndex(), ti.indices());
+        updateBatch(batch, value, data, value.getClass(), ti.naturalIndex(), ti.indices(), null);
         db().write(batch);
       }
     }
@@ -164,35 +166,44 @@ public void writeAll(List<?> values) throws Exception {
     Preconditions.checkArgument(values != null && !values.isEmpty(),
       "Non-empty values required.");
 
-    // Group by class, in case there are values from different classes in the values
+    // Group by class, in case there are values from different classes in the values.
     // Typical usecase is for this to be a single class.
     // A NullPointerException will be thrown if values contain null object.
     for (Map.Entry<? extends Class<?>, ? extends List<?>> entry :
         values.stream().collect(Collectors.groupingBy(Object::getClass)).entrySet()) {
-
-      final Iterator<?> valueIter = entry.getValue().iterator();
-      final Iterator<byte[]> serializedValueIter;
-
-      // Deserialize outside synchronized block
-      List<byte[]> list = new ArrayList<>(entry.getValue().size());
-      for (Object value : values) {
-        list.add(serializer.serialize(value));
-      }
-      serializedValueIter = list.iterator();
-
       final Class<?> klass = entry.getKey();
-      final LevelDBTypeInfo ti = getTypeInfo(klass);
-
-      synchronized (ti) {
-        final LevelDBTypeInfo.Index naturalIndex = ti.naturalIndex();
-        final Collection<LevelDBTypeInfo.Index> indices = ti.indices();
+      // Partition the large value list to a set of smaller batches, to reduce the memory
+      // pressure caused by serialization and give fairness to other writing threads.
+      for (List<?> batchList : Iterables.partition(entry.getValue(), 128)) {
+        final Iterator<?> valueIter = batchList.iterator();
+        final Iterator<byte[]> serializedValueIter;
 
-      try (WriteBatch batch = db().createWriteBatch()) {
-        while (valueIter.hasNext()) {
-          updateBatch(batch, valueIter.next(), serializedValueIter.next(), klass,
-            naturalIndex, indices);
-        }
+        // Deserialize outside synchronized block
+        List<byte[]> serializedValueList = new ArrayList<>(batchList.size());
+        for (Object value : batchList) {
+          serializedValueList.add(serializer.serialize(value));
+        }
+        serializedValueIter = serializedValueList.iterator();
+
+        final LevelDBTypeInfo ti = getTypeInfo(klass);
+        synchronized (ti) {
+          final LevelDBTypeInfo.Index naturalIndex = ti.naturalIndex();
+          final Collection<LevelDBTypeInfo.Index> indices = ti.indices();
+
+          try (WriteBatch batch = db().createWriteBatch()) {
+            // A hash map to update the delta of each countKey, wrap countKey with type byte[]
+            // as ByteBuffer because ByteBuffer is comparable.
+            Map<ByteBuffer, Long> counts = new HashMap<>();
+            while (valueIter.hasNext()) {
+              updateBatch(batch, valueIter.next(), serializedValueIter.next(), klass,
+                naturalIndex, indices, counts);
+            }
+            for (Map.Entry<ByteBuffer, Long> countEntry : counts.entrySet()) {
+              naturalIndex.updateCount(batch, countEntry.getKey().array(), countEntry.getValue());
+            }
+            db().write(batch);
+          }
-        db().write(batch);
         }
       }
     }
@@ -204,7 +215,8 @@ private void updateBatch(
       byte[] data,
       Class<?> klass,
       LevelDBTypeInfo.Index naturalIndex,
-      Collection<LevelDBTypeInfo.Index> indices) throws Exception {
+      Collection<LevelDBTypeInfo.Index> indices,
+      Map<ByteBuffer, Long> counts) throws Exception {
     Object existing;
     try {
       existing = get(naturalIndex.entityKey(null, value), klass);
@@ -216,7 +228,7 @@ private void updateBatch(
     byte[] naturalKey = naturalIndex.toKey(naturalIndex.getValue(value));
     for (LevelDBTypeInfo.Index idx : indices) {
       byte[] prefix = cache.getPrefix(idx);
-      idx.add(batch, value, existing, data, naturalKey, prefix);
+      idx.add(batch, value, existing, data, naturalKey, prefix, counts);
     }
   }
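The writeAll() change above has two parts: Iterables.partition(entry.getValue(), 128) splits each per-class value list into fixed-size batches, each serialized and committed as its own WriteBatch, and a per-batch counts map defers index-count updates to a single flush. Incidentally, the old loop serialized the whole values list once per class group (for (Object value : values)), while the new code serializes only the current batchList. Below is a minimal, self-contained sketch of the partitioning step; the class name PartitionDemo and the plain Integer values are illustrative stand-ins, not part of the patch.

```java
import com.google.common.collect.Iterables;
import java.util.ArrayList;
import java.util.List;

public class PartitionDemo {
  public static void main(String[] args) {
    // Stand-in for the per-class value list that writeAll() iterates over.
    List<Integer> values = new ArrayList<>();
    for (int i = 0; i < 300; i++) {
      values.add(i);
    }

    // Iterables.partition returns consecutive sublists of at most 128
    // elements, so 300 values become three batches of 128, 128 and 44.
    // In the patch each batch is serialized and written independently,
    // bounding peak memory and releasing the type-info lock between
    // batches so other writers can interleave.
    for (List<Integer> batch : Iterables.partition(values, 128)) {
      System.out.println("batch size: " + batch.size());
    }
    // Prints: batch size: 128 / batch size: 128 / batch size: 44
  }
}
```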
diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java
index d7423537ddfcf..654730f899508 100644
--- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java
+++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBTypeInfo.java
@@ -18,6 +18,7 @@
 package org.apache.spark.util.kvstore;
 
 import java.lang.reflect.Array;
+import java.nio.ByteBuffer;
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
@@ -314,7 +315,7 @@ byte[] entityKey(byte[] prefix, Object entity) throws Exception {
       return entityKey;
     }
 
-    private void updateCount(WriteBatch batch, byte[] key, long delta) {
+    void updateCount(WriteBatch batch, byte[] key, long delta) {
       long updated = getCount(key) + delta;
       if (updated > 0) {
         batch.put(key, db.serializer.serialize(updated));
@@ -323,13 +324,19 @@ private void updateCount(WriteBatch batch, byte[] key, long delta) {
       }
     }
 
+    private void updateCount(Map<ByteBuffer, Long> counts, byte[] key, long delta) {
+      ByteBuffer buffer = ByteBuffer.wrap(key);
+      counts.put(buffer, counts.getOrDefault(buffer, 0L) + delta);
+    }
+
     private void addOrRemove(
         WriteBatch batch,
         Object entity,
         Object existing,
         byte[] data,
         byte[] naturalKey,
-        byte[] prefix) throws Exception {
+        byte[] prefix,
+        Map<ByteBuffer, Long> counts) throws Exception {
       Object indexValue = getValue(entity);
       Preconditions.checkNotNull(indexValue, "Null index value for %s in type %s.",
         name, type.getName());
@@ -376,7 +383,11 @@ private void addOrRemove(
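The new private updateCount overload accumulates deltas in a Map<ByteBuffer, Long> rather than a Map<byte[], Long>. That choice matters: byte[] inherits identity-based equals() and hashCode() from Object, so two equal-content count keys would occupy separate map entries, whereas ByteBuffer.wrap() shares the backing array and defines content-based equality (and Comparable, per the code comment). A minimal sketch of that behavior; the class name CountKeyDemo is illustrative.

```java
import java.nio.ByteBuffer;
import java.util.HashMap;
import java.util.Map;

public class CountKeyDemo {
  public static void main(String[] args) {
    // Two distinct arrays with the same content, like two count keys
    // computed for the same index value during one writeAll() batch.
    byte[] k1 = {1, 2, 3};
    byte[] k2 = {1, 2, 3};

    // Arrays compare by identity, so a Map<byte[], Long> would keep
    // two separate entries for the "same" key.
    System.out.println(k1.equals(k2)); // false

    // ByteBuffer.wrap() compares buffer contents, so both wrappers hit
    // the same entry and the deltas merge, mirroring the private
    // updateCount(Map, byte[], long) overload in the patch.
    Map<ByteBuffer, Long> counts = new HashMap<>();
    ByteBuffer b1 = ByteBuffer.wrap(k1);
    ByteBuffer b2 = ByteBuffer.wrap(k2);
    counts.put(b1, counts.getOrDefault(b1, 0L) + 1L);
    counts.put(b2, counts.getOrDefault(b2, 0L) + 1L);

    System.out.println(counts.size());  // 1
    System.out.println(counts.get(b1)); // 2
    // countEntry.getKey().array() later recovers the raw key for LevelDB.
  }
}
```

Deferring the deltas this way also collapses many read-modify-write cycles on the same count key into one getCount() read and one put per key when the batch is flushed.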
         // end markers for the indexed value.
         if (!isChild()) {
           byte[] oldCountKey = end(null, oldIndexedValue);
-          updateCount(batch, oldCountKey, -1L);
+          if (counts != null) {
+            updateCount(counts, oldCountKey, -1L);
+          } else {
+            updateCount(batch, oldCountKey, -1L);
+          }
           needCountUpdate = true;
         }
       }
@@ -392,7 +403,11 @@ private void addOrRemove(
       if (needCountUpdate && !isChild()) {
         long delta = data != null ? 1L : -1L;
         byte[] countKey = isNatural ? end(prefix) : end(prefix, indexValue);
-        updateCount(batch, countKey, delta);
+        if (counts != null) {
+          updateCount(counts, countKey, delta);
+        } else {
+          updateCount(batch, countKey, delta);
+        }
       }
     }
 
@@ -405,6 +420,7 @@ private void addOrRemove(
      * @param data Serialized entity to store (when storing the entity, not a reference).
      * @param naturalKey The value's natural key (to avoid re-computing it for every index).
      * @param prefix The parent index prefix, if this is a child index.
+     * @param counts A hash map to update the delta of each countKey, used when calling writeAll.
      */
     void add(
         WriteBatch batch,
@@ -412,8 +428,9 @@ void add(
         Object existing,
         byte[] data,
         byte[] naturalKey,
-        byte[] prefix) throws Exception {
-      addOrRemove(batch, entity, existing, data, naturalKey, prefix);
+        byte[] prefix,
+        Map<ByteBuffer, Long> counts) throws Exception {
+      addOrRemove(batch, entity, existing, data, naturalKey, prefix, counts);
     }
 
     /**
@@ -429,7 +446,7 @@ void remove(
         Object entity,
         byte[] naturalKey,
         byte[] prefix) throws Exception {
-      addOrRemove(batch, entity, null, null, naturalKey, prefix);
+      addOrRemove(batch, entity, null, null, naturalKey, prefix, null);
     }
 
     long getCount(byte[] key) {
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
index f6566617765d4..bf3f497a23cbf 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
@@ -18,6 +18,7 @@
 package org.apache.spark.util.kvstore;
 
 import java.io.File;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.List;
@@ -200,6 +201,28 @@ public void testUpdate() throws Exception {
     assertEquals(0, db.count(t.getClass(), "name", "name"));
   }
 
+  @Test
+  public void testWriteAll() throws Exception {
+    List<Object> values = new ArrayList<>();
+    for (int i = 0; i < 2; i++) {
+      CustomType1 t1 = createCustomType1(i);
+      values.add(t1);
+
+      CustomType2 t2 = createCustomType2(i, i);
+      values.add(t2);
+    }
+
+    CustomType1 t = createCustomType1(3);
+    t.id = "id1"; // test count of "id" index
+    values.add(t);
+
+    assertEquals(5, values.size());
+    db.writeAll(values);
+    assertEquals(3, db.count(CustomType1.class));
+    assertEquals(2, db.count(CustomType1.class, "id", t.id));
+    assertEquals(2, db.count(CustomType2.class));
+  }
+
   @Test
   public void testRemoveAll() throws Exception {
     for (int i = 0; i < 2; i++) {
@@ -322,6 +345,14 @@ private CustomType1 createCustomType1(int i) {
     return t;
   }
 
+  private CustomType2 createCustomType2(int i, int parentId) {
+    CustomType2 t = new CustomType2();
+    t.key = "key" + i;
+    t.id = "id" + i;
+    t.parentId = "parentId" + parentId;
+    return t;
+  }
+
   private int countKeys(Class<?> type) throws Exception {
     byte[] prefix = db.getTypeInfo(type).keyPrefix();
     int count = 0;
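The new testWriteAll() deliberately mixes CustomType1 and CustomType2 instances in one list, which exercises the group-by-class step at the top of writeAll() before any batching happens: three CustomType1 rows (two sharing id "id1" once the override is applied) and two CustomType2 rows, matching the final count assertions. A self-contained sketch of that grouping; GroupByClassDemo and the String/Integer values are illustrative stand-ins for the suite's custom types.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

public class GroupByClassDemo {
  public static void main(String[] args) {
    // A mixed-type list, like the CustomType1/CustomType2 mix that
    // testWriteAll() passes to db.writeAll().
    List<Object> values = Arrays.asList("a", "b", 1, 2, 3);

    // The same grouping writeAll() performs: one bucket per runtime
    // class, so type info and indices are resolved once per class
    // rather than once per value.
    Map<? extends Class<?>, ? extends List<Object>> grouped =
        values.stream().collect(Collectors.groupingBy(Object::getClass));

    System.out.println(grouped.get(String.class).size());  // 2
    System.out.println(grouped.get(Integer.class).size()); // 3
  }
}
```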