Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
480a74a
Initial import of code from Databricks unsafe utils repo.
JoshRosen Apr 17, 2015
ab68e08
Begin merging the UTF8String implementations.
JoshRosen Apr 18, 2015
f03e9c1
Play around with Unsafe implementations of more string methods.
JoshRosen Apr 18, 2015
5d55cef
Add skeleton for Row implementation.
JoshRosen Apr 18, 2015
8a8f9df
Add skeleton for GeneratedAggregate integration.
JoshRosen Apr 18, 2015
1ff814d
Add reminder to free memory on iterator completion
JoshRosen Apr 18, 2015
53ba9b7
Start prototyping Java Row -> UnsafeRow converters
JoshRosen Apr 19, 2015
fc4c3a8
Sketch how the converters will be used in UnsafeGeneratedAggregate
JoshRosen Apr 19, 2015
1a483c5
First version that passes some aggregation tests:
JoshRosen Apr 19, 2015
079f1bf
Some clarification of the BytesToBytesMap.lookup() / set() contract.
JoshRosen Apr 19, 2015
f764d13
Simplify address + length calculation in Location.
JoshRosen Apr 19, 2015
c754ae1
Now that the store*() contract has been strengthened, we can remove an…
JoshRosen Apr 19, 2015
ae39694
Add finalizer as "cleanup method of last resort"
JoshRosen Apr 19, 2015
c7f0b56
Reuse UnsafeRow pointer in UnsafeRowConverter
JoshRosen Apr 20, 2015
62ab054
Optimize for fact that get() is only called on String columns.
JoshRosen Apr 20, 2015
c55bf66
Free buffer once iterator has been fully consumed.
JoshRosen Apr 20, 2015
738fa33
Add feature flag to guard UnsafeGeneratedAggregate
JoshRosen Apr 20, 2015
c1b3813
Fix bug in UnsafeMemoryAllocator.free():
JoshRosen Apr 20, 2015
7df6008
Optimizations related to zeroing out memory:
JoshRosen Apr 21, 2015
58ac393
Use UNSAFE allocator in GeneratedAggregate (TODO: make this configura…
JoshRosen Apr 21, 2015
d2bb986
Update to implement new Row methods added upstream
JoshRosen Apr 22, 2015
b3eaccd
Extract aggregation map into its own class.
JoshRosen Apr 22, 2015
bade966
Comment update (bumping to refresh GitHub cache...)
JoshRosen Apr 22, 2015
d85eeff
Add basic sanity test for UnsafeFixedWidthAggregationMap
JoshRosen Apr 22, 2015
1f4b716
Merge Unsafe code into the regular GeneratedAggregate, guarded by a
JoshRosen Apr 22, 2015
92d5a06
Address a number of minor code review comments.
JoshRosen Apr 23, 2015
628f936
Use ints instead of longs for indexing.
JoshRosen Apr 23, 2015
23a440a
Bump up default hash map size
JoshRosen Apr 23, 2015
765243d
Enable optional performance metrics for hash map.
JoshRosen Apr 23, 2015
b26f1d3
Fix bug in murmur hash implementation.
JoshRosen Apr 23, 2015
49aed30
More long -> int conversion.
JoshRosen Apr 23, 2015
29a7575
Remove debug logging
JoshRosen Apr 24, 2015
ef6b3d3
Fix a bunch of FindBugs and IntelliJ inspections
JoshRosen Apr 24, 2015
06e929d
More warning cleanup
JoshRosen Apr 24, 2015
854201a
Import and comment cleanup
JoshRosen Apr 24, 2015
f3dcbfe
More mod replacement
JoshRosen Apr 24, 2015
afe8dca
Some Javadoc cleanup
JoshRosen Apr 24, 2015
a95291e
Cleanups to string handling code
JoshRosen Apr 24, 2015
31eaabc
Lots of TODO and doc cleanup.
JoshRosen Apr 24, 2015
6ffdaa1
Null handling improvements in UnsafeRow.
JoshRosen Apr 24, 2015
9c19fc0
Add configuration options for heap vs. offheap
JoshRosen Apr 24, 2015
cde4132
Add missing pom.xml
JoshRosen Apr 26, 2015
0925847
Disable MiMa checks for new unsafe module
JoshRosen Apr 27, 2015
a8e4a3f
Introduce MemoryManager interface; add to SparkEnv.
JoshRosen Apr 28, 2015
b45f070
Don't redundantly store the offset from key to value, since we can co…
JoshRosen Apr 28, 2015
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@
<artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
Expand Down
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
import org.apache.spark.unsafe.memory.{MemoryManager => UnsafeMemoryManager, MemoryAllocator}
import org.apache.spark.util.{RpcUtils, Utils}

/**
Expand Down Expand Up @@ -69,6 +70,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
val unsafeMemoryManager: UnsafeMemoryManager,
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {

Expand Down Expand Up @@ -382,6 +384,15 @@ object SparkEnv extends Logging {
new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)

// Memory manager backing the new "unsafe" operators. The spark.unsafe.offHeap flag
// (default: false) selects the UNSAFE (off-heap) allocator; otherwise the HEAP
// (on-heap) allocator is used.
val unsafeMemoryManager: UnsafeMemoryManager = {
val allocator = if (conf.getBoolean("spark.unsafe.offHeap", false)) {
MemoryAllocator.UNSAFE
} else {
MemoryAllocator.HEAP
}
new UnsafeMemoryManager(allocator)
}

val envInstance = new SparkEnv(
executorId,
rpcEnv,
Expand All @@ -398,6 +409,7 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
unsafeMemoryManager,
outputCommitCoordinator,
conf)

Expand Down
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
<module>sql/catalyst</module>
<module>sql/core</module>
<module>sql/hive</module>
<module>unsafe</module>
<module>assembly</module>
<module>external/twitter</module>
<module>external/flume</module>
Expand Down
6 changes: 3 additions & 3 deletions project/SparkBuild.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ object BuildCommons {

val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
streamingMqtt, streamingTwitter, streamingZeromq, launcher) =
streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
"sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
"streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
"streaming-zeromq", "launcher").map(ProjectRef(buildLocation, _))
"streaming-zeromq", "launcher", "unsafe").map(ProjectRef(buildLocation, _))

val optionallyEnabledProjects@Seq(yarn, yarnStable, java8Tests, sparkGangliaLgpl,
sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
Expand Down Expand Up @@ -159,7 +159,7 @@ object SparkBuild extends PomBuild {
// TODO: Add Sql to mima checks
// TODO: remove launcher from this list after 1.3.
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
networkCommon, networkShuffle, networkYarn, launcher).contains(x)).foreach {
networkCommon, networkShuffle, networkYarn, launcher, unsafe).contains(x)).foreach {
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}

Expand Down
5 changes: 5 additions & 0 deletions sql/catalyst/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
<artifactId>scalacheck_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions;

import java.util.Arrays;
import java.util.Iterator;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.map.BytesToBytesMap;
import org.apache.spark.unsafe.memory.MemoryLocation;
import org.apache.spark.unsafe.memory.MemoryManager;

/**
 * Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
 *
 * This map supports a maximum of 2 billion keys.
 */
public final class UnsafeFixedWidthAggregationMap {

  /**
   * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the
   * map, we copy this buffer and use it as the value.
   */
  private final long[] emptyAggregationBuffer;

  // Schema of the fixed-width aggregation values stored in the map.
  private final StructType aggregationBufferSchema;

  // Schema of the keys used to probe the map.
  private final StructType groupingKeySchema;

  /**
   * Encodes grouping keys as UnsafeRows.
   */
  private final UnsafeRowConverter groupingKeyToUnsafeRowConverter;

  /**
   * A hashmap which maps from opaque bytearray keys to bytearray values.
   */
  private final BytesToBytesMap map;

  /**
   * Re-used pointer to the current aggregation buffer
   */
  private final UnsafeRow currentAggregationBuffer = new UnsafeRow();

  /**
   * Scratch space that is used when encoding grouping keys into UnsafeRow format.
   *
   * By default, this is a 1 KB array (128 8-byte words), but it will grow as necessary in case
   * larger keys are encountered.
   */
  private long[] groupingKeyConversionScratchSpace = new long[1024 / 8];

  // If true, lookup/collision/resize statistics are collected by the underlying map.
  private final boolean enablePerfMetrics;

  /**
   * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema,
   *         false otherwise.
   */
  public static boolean supportsGroupKeySchema(StructType schema) {
    // Keys only need to be readable back out of UnsafeRow format, not updated in place.
    for (StructField field: schema.fields()) {
      if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) {
        return false;
      }
    }
    return true;
  }

  /**
   * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given
   *         schema, false otherwise.
   */
  public static boolean supportsAggregationBufferSchema(StructType schema) {
    // Aggregation buffers are mutated in place, so every column type must be settable.
    for (StructField field: schema.fields()) {
      if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) {
        return false;
      }
    }
    return true;
  }

  /**
   * Create a new UnsafeFixedWidthAggregationMap.
   *
   * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function)
   * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion.
   * @param groupingKeySchema the schema of the grouping key, used for row conversion.
   * @param memoryManager the memory manager used to allocate our Unsafe memory structures.
   * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing).
   * @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact)
   */
  public UnsafeFixedWidthAggregationMap(
      Row emptyAggregationBuffer,
      StructType aggregationBufferSchema,
      StructType groupingKeySchema,
      MemoryManager memoryManager,
      int initialCapacity,
      boolean enablePerfMetrics) {
    // Pre-serialize the empty buffer once; getAggregationBuffer() copies these bytes for each
    // newly-seen key.
    this.emptyAggregationBuffer =
      convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema);
    this.aggregationBufferSchema = aggregationBufferSchema;
    this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema);
    this.groupingKeySchema = groupingKeySchema;
    this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics);
    this.enablePerfMetrics = enablePerfMetrics;
  }

  /**
   * Convert a Java object row into an UnsafeRow, allocating it into a new long array.
   */
  private static long[] convertToUnsafeRow(Row javaRow, StructType schema) {
    final UnsafeRowConverter converter = new UnsafeRowConverter(schema);
    // NOTE(review): getSizeRequirement() appears to return a length in 8-byte words (it is used
    // directly as a long[] length and compared with writeRow()'s result) — confirm in converter.
    final long[] unsafeRow = new long[converter.getSizeRequirement(javaRow)];
    final long writtenLength =
      converter.writeRow(javaRow, unsafeRow, PlatformDependent.LONG_ARRAY_OFFSET);
    assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!";
    return unsafeRow;
  }

  /**
   * Return the aggregation buffer for the current group. For efficiency, all calls to this method
   * return the same object.
   */
  public UnsafeRow getAggregationBuffer(Row groupingKey) {
    // Zero out the buffer that's used to hold the current row. This is necessary in order
    // to ensure that rows hash properly, since garbage data from the previous row could
    // otherwise end up as padding in this row. (If the scratch space is replaced below, the
    // freshly-allocated array is already zeroed by the JVM.)
    Arrays.fill(groupingKeyConversionScratchSpace, 0);
    final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey);
    if (groupingKeySize > groupingKeyConversionScratchSpace.length) {
      // Grow the scratch space to fit this key; the larger array is retained for future calls.
      groupingKeyConversionScratchSpace = new long[groupingKeySize];
    }
    final long actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow(
      groupingKey,
      groupingKeyConversionScratchSpace,
      PlatformDependent.LONG_ARRAY_OFFSET);
    assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!";

    // Probe our map using the serialized key
    final BytesToBytesMap.Location loc = map.lookup(
      groupingKeyConversionScratchSpace,
      PlatformDependent.LONG_ARRAY_OFFSET,
      groupingKeySize);
    if (!loc.isDefined()) {
      // This is the first time that we've seen this grouping key, so we'll insert a copy of the
      // empty aggregation buffer into the map:
      loc.putNewKey(
        groupingKeyConversionScratchSpace,
        PlatformDependent.LONG_ARRAY_OFFSET,
        groupingKeySize,
        emptyAggregationBuffer,
        PlatformDependent.LONG_ARRAY_OFFSET,
        emptyAggregationBuffer.length
      );
    }

    // Reset the pointer to point to the value that we just stored or looked up:
    final MemoryLocation address = loc.getValueAddress();
    currentAggregationBuffer.pointTo(
      address.getBaseObject(),
      address.getBaseOffset(),
      aggregationBufferSchema.length(),
      aggregationBufferSchema
    );
    return currentAggregationBuffer;
  }

  /**
   * Mutable key/value pair returned by {@link UnsafeFixedWidthAggregationMap#iterator()}; the
   * rows point directly into the map's memory and are overwritten on each call to next().
   */
  public static class MapEntry {
    public final UnsafeRow key = new UnsafeRow();
    public final UnsafeRow value = new UnsafeRow();
  }

  /**
   * Returns an iterator over the keys and values in this map.
   *
   * For efficiency, each call returns the same object.
   */
  public Iterator<MapEntry> iterator() {
    return new Iterator<MapEntry>() {

      // Single re-used entry; its rows are re-pointed at each location's memory in next().
      private final MapEntry entry = new MapEntry();
      private final Iterator<BytesToBytesMap.Location> mapLocationIterator = map.iterator();

      @Override
      public boolean hasNext() {
        return mapLocationIterator.hasNext();
      }

      @Override
      public MapEntry next() {
        final BytesToBytesMap.Location loc = mapLocationIterator.next();
        final MemoryLocation keyAddress = loc.getKeyAddress();
        final MemoryLocation valueAddress = loc.getValueAddress();
        entry.key.pointTo(
          keyAddress.getBaseObject(),
          keyAddress.getBaseOffset(),
          groupingKeySchema.length(),
          groupingKeySchema
        );
        entry.value.pointTo(
          valueAddress.getBaseObject(),
          valueAddress.getBaseOffset(),
          aggregationBufferSchema.length(),
          aggregationBufferSchema
        );
        return entry;
      }

      @Override
      public void remove() {
        throw new UnsupportedOperationException();
      }
    };
  }

  /**
   * Free the unsafe memory associated with this map.
   */
  public void free() {
    map.free();
  }

  // Dumps the underlying map's lookup/collision/resize statistics to stdout; only valid when
  // the map was constructed with enablePerfMetrics = true.
  @SuppressWarnings("UseOfSystemOutOrSystemErr")
  public void printPerfMetrics() {
    if (!enablePerfMetrics) {
      throw new IllegalStateException("Perf metrics not enabled");
    }
    System.out.println("Average probes per lookup: " + map.getAverageProbesPerLookup());
    System.out.println("Number of hash collisions: " + map.getNumHashCollisions());
    System.out.println("Time spent resizing (ms): " + map.getTimeSpentResizingMs());
    System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption());
  }

}
Loading