From 06bf9f994f1a0f654e85d3e6ff267f1ca106880f Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 19 Dec 2018 12:05:37 -0800 Subject: [PATCH 01/30] initial merge --- .../protocol/OpenShufflePartition.java | 77 +++++++++ .../UploadShufflePartitionStream.java | 112 ++++++++++++++ .../external/ExternalShuffleDataIO.java | 58 +++++++ .../ExternalShuffleMapOutputWriter.java | 82 ++++++++++ .../ExternalShufflePartitionReader.java | 94 +++++++++++ .../external/ExternalShuffleReadSupport.java | 62 ++++++++ .../external/ExternalShuffleWriteSupport.java | 89 +++++++++++ .../ShuffleServiceAddressProvider.scala | 31 ++++ ...ShuffleServiceAddressProviderFactory.scala | 28 ++++ ...ernetesShuffleServiceAddressProvider.scala | 146 ++++++++++++++++++ ...ShuffleServiceAddressProviderFactory.scala | 48 ++++++ 11 files changed, 827 insertions(+) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala create mode 100644 core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java new file mode 100644 index 000000000000..02b09a83cfb7 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java @@ -0,0 +1,77 @@ +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +public class OpenShufflePartition extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + public final int partitionId; + + public OpenShufflePartition(String appId, String execId, int shuffleId, int mapId, int partitionId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof OpenShufflePartition) { + OpenShufflePartition o = (OpenShufflePartition) other; + return Objects.equal(appId, o.appId) + && execId == o.execId + && shuffleId == o.shuffleId + && mapId == o.mapId + && partitionId == o.partitionId; + } + return false; + } + + @Override + protected Type type() { + return null; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, shuffleId, mapId, partitionId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + buf.writeInt(partitionId); + } + + public static OpenShufflePartition decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + int partitionId = buf.readInt(); + return new OpenShufflePartition(appId, execId, shuffleId, mapId, partitionId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java new file mode 100644 index 000000000000..0a6c78e27c1a --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +/** + * Upload shuffle partition request to the External Shuffle Service. + * This request should also include the driverHostPort for the sake of + * setting up a driver heartbeat to monitor heartbeat + */ +public class UploadShufflePartitionStream extends BlockTransferMessage { + public final String driverHostPort; + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + public final int partitionId; + + public UploadShufflePartitionStream( + String driverHostPort, + String appId, + String execId, + int shuffleId, + int mapId, + int partitionId) { + this.driverHostPort = driverHostPort; + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadShufflePartitionStream) { + UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + return Objects.equal(appId, o.appId) + && driverHostPort == o.driverHostPort + && execId == o.execId + && shuffleId == o.shuffleId + && mapId == o.mapId + && partitionId == o.partitionId; + } + return false; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_PARTITION_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode(driverHostPort, appId, execId, shuffleId, mapId, partitionId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("driverHostPort", driverHostPort) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(driverHostPort) + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, driverHostPort); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + buf.writeInt(partitionId); + } + + public static UploadShufflePartitionStream decode(ByteBuf buf) { + String driverHostPort = Encoders.Strings.decode(buf); + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + int partitionId = buf.readInt(); + return new UploadShufflePartitionStream(driverHostPort, appId, execId, shuffleId, mapId, partitionId); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java new file mode 100644 index 000000000000..3ce83201665a --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -0,0 +1,58 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.network.netty.SparkTransportConf; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.shuffle.api.ShuffleDataIO; +import org.apache.spark.shuffle.api.ShuffleReadSupport; +import org.apache.spark.shuffle.api.ShuffleWriteSupport; +import org.apache.spark.SecurityManager; +import org.apache.spark.util.Utils; + +public class ExternalShuffleDataIO implements ShuffleDataIO { + + private static final String SHUFFLE_SERVICE_PORT_CONFIG = "spark.shuffle.service.port"; + private static final String DEFAULT_SHUFFLE_PORT = "7337"; + + private final SparkConf sparkConf; + private final TransportConf conf; + private final SecurityManager securityManager; + private final String hostname; + private final int port; + private final String execId; + + public ExternalShuffleDataIO( + SparkConf sparkConf) { + this.sparkConf = sparkConf; + this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 2); + this.securityManager = SparkEnv.get().securityManager(); + this.hostname = SparkEnv.get().blockManager().blockTransferService().hostName(); + + int tmpPort = Integer.parseInt( + Utils.getSparkOrYarnConfig(sparkConf, SHUFFLE_SERVICE_PORT_CONFIG, DEFAULT_SHUFFLE_PORT)); + if (tmpPort == 0) { + this.port = Integer.parseInt(sparkConf.get(SHUFFLE_SERVICE_PORT_CONFIG)); + } else { + this.port = tmpPort; + } + this.execId = SparkEnv.get().blockManager().shuffleServerId().executorId(); + } + + @Override + public void initialize() { + // TODO: hmmmm? maybe register? idk + } + + @Override + public ShuffleReadSupport readSupport() { + return new ExternalShuffleReadSupport( + conf, securityManager.isAuthenticationEnabled(), securityManager, hostname, port, execId); + } + + @Override + public ShuffleWriteSupport writeSupport() { + return new ExternalShuffleWriteSupport( + conf, securityManager.isAuthenticationEnabled(), securityManager, hostname, port, execId); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java new file mode 100644 index 000000000000..c75b0b9dbc56 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -0,0 +1,82 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; + +public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { + + private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); + + private final TransportClient client; + private final String appId; + private final String execId; + private final int shuffleId; + private final int mapId; + private final int partitionId; + + private long totalLength = 0; + + public ExternalShufflePartitionWriter( + TransportClient client, + String appId, + String execId, + int shuffleId, + int mapId, + int partitionId) { + this.client = client; + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.partitionId = partitionId; + } + + @Override + public void appendBytesToPartition(InputStream streamReadingBytesToAppend) { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully uploaded partition"); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Encountered an error uploading partition", e); + } + }; + try { + ByteBuffer streamHeader = + new UploadShufflePartitionStream(this.appId, execId, shuffleId, mapId, partitionId).toByteBuffer(); + int avaibleSize = streamReadingBytesToAppend.available(); + byte[] buf = new byte[avaibleSize]; + int size = streamReadingBytesToAppend.read(buf, 0, avaibleSize); + assert size == avaibleSize; + ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); + client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); + totalLength += size; + } catch (Exception e) { + logger.error("Encountered error while attempting to upload partition to ESS", e); + throw new RuntimeException(e); + } + } + + @Override + public long commitAndGetTotalLength() { + return totalLength; + } + + @Override + public void abort(Exception failureReason) { + // TODO + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java new file mode 100644 index 000000000000..75630566b1da --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -0,0 +1,94 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.StreamCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.server.OneForOneStreamManager; +import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; +import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; +import org.apache.spark.network.shuffle.protocol.StreamHandle; +import org.apache.spark.shuffle.api.ShufflePartitionReader; +import org.apache.spark.util.ByteBufferInputStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.io.InputStream; +import java.io.SequenceInputStream; +import java.nio.ByteBuffer; +import java.util.Vector; + +public class ExternalShufflePartitionReader implements ShufflePartitionReader { + + private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionReader.class); + + private final TransportClient client; + private final String appId; + private final String execId; + private final int shuffleId; + private final int mapId; + + public ExternalShufflePartitionReader(TransportClient client, String appId, String execId, int shuffleId, int mapId) { + this.client = client; + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public InputStream fetchPartition(int reduceId) { + OpenShufflePartition openMessage = new OpenShufflePartition(appId, execId, shuffleId, mapId, reduceId); + + ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000 /* what should be the default? */); + + try { + StreamCombiningCallback callback = new StreamCombiningCallback(); + StreamHandle streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); + for (int i = 0; i < streamHandle.numChunks; i++) { + client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), + callback); + } + return callback.getCombinedInputStream(); + } catch (Exception e) { + logger.error("Encountered exception while trying to fetch blocks", e); + throw new RuntimeException(e); + } + } + + private class StreamCombiningCallback implements StreamCallback { + + public boolean failed; + public final Vector inputStreams; + + public StreamCombiningCallback() { + inputStreams = new Vector<>(); + failed = false; + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + inputStreams.add(new ByteBufferInputStream(buf)); + } + + @Override + public void onComplete(String streamId) throws IOException { + // do nothing + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + failed = true; + for (InputStream stream : inputStreams) { + stream.close(); + } + } + + public SequenceInputStream getCombinedInputStream() { + if (failed) { + throw new RuntimeException("Stream chunk gathering failed"); + } + return new SequenceInputStream(inputStreams.elements()); + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java new file mode 100644 index 000000000000..7951a5318816 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -0,0 +1,62 @@ +package org.apache.spark.shuffle.external; + +import com.google.common.collect.Lists; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.crypto.AuthClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.shuffle.api.ShufflePartitionReader; +import org.apache.spark.shuffle.api.ShuffleReadSupport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class ExternalShuffleReadSupport implements ShuffleReadSupport { + + private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleReadSupport.class); + + private final TransportConf conf; + private final boolean authEnabled; + private final SecretKeyHolder secretKeyHolder; + private final String hostname; + private final int port; + private final String execId; + + public ExternalShuffleReadSupport( + TransportConf conf, + boolean authEnabled, + SecretKeyHolder secretKeyHolder, + String hostname, + int port, + String execId) { + this.conf = conf; + this.authEnabled = authEnabled; + this.secretKeyHolder = secretKeyHolder; + this.hostname = hostname; + this.port = port; + this.execId = execId; + } + + @Override + public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) { + // TODO combine this into a function with ExternalShuffleWriteSupport + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + } + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + try { + TransportClient client = clientFactory.createClient(hostname, port); + return new ExternalShufflePartitionReader(client, appId, execId, shuffleId, mapId); + } catch (Exception e) { + logger.error("Encountered error while creating transport client"); + throw new RuntimeException(e); // what is standard practice here? + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java new file mode 100644 index 000000000000..102268371cd0 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -0,0 +1,89 @@ +package org.apache.spark.shuffle.external; + +import com.google.common.collect.Lists; +import org.apache.spark.network.TransportContext; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientBootstrap; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.crypto.AuthClientBootstrap; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.util.TransportConf; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.ShuffleWriteSupport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { + + private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class); + + private final TransportConf conf; + private final boolean authEnabled; + private final SecretKeyHolder secretKeyHolder; + private final String hostname; + private final int port; + private final String execId; + + public ExternalShuffleWriteSupport( + TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, String hostname, int port, String execId) { + this.conf = conf; + this.authEnabled = authEnabled; + this.secretKeyHolder = secretKeyHolder; + this.hostname = hostname; + this.port = port; + this.execId = execId; + } + + @Override + public ShufflePartitionWriter newPartitionWriter(String appId, int shuffleId, int mapId, int partitionId) { + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + } + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + try { + TransportClient client = clientFactory.createClient(hostname, port); + return new ExternalShufflePartitionWriter(client, appId, execId, shuffleId, mapId, partitionId); + } catch (Exception e) { + logger.error("Encountered error while creating transport client"); + throw new RuntimeException(e); // what is standard practice here? + } + } + + @Override + public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); + } + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + return new ShuffleMapOutputWriter() { + @Override + public ShufflePartitionWriter newPartitionWriter(int partitionId) { + try { + TransportClient client = clientFactory.createClient(hostname, port); + return new ExternalShufflePartitionWriter(client, appId, execId, shuffleId, mapId, partitionId); + } catch (Exception e) { + logger.error("Encountered error while creating transport client"); + throw new RuntimeException(e); // what is standard practice here? + } + } + + @Override + public void commitAllPartitions() { + + } + + @Override + public void abort(Exception exception) { + + } + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala new file mode 100644 index 000000000000..0a643a678579 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +trait ShuffleServiceAddressProvider { + + def start(): Unit = {} + + def getShuffleServiceAddresses(): List[(String, Int)] + + def stop(): Unit = {} +} + +private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { + override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] +} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala new file mode 100644 index 000000000000..eaa30110093c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.remote + +import org.apache.spark.SparkConf + +trait ShuffleServiceAddressProviderFactory { + + def canCreate(masterUrl: String): Boolean + + def create(conf: SparkConf): ShuffleServiceAddressProvider + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala new file mode 100644 index 000000000000..39f6a9a3a213 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.k8s + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.cluster.k8s._ +import org.apache.spark.shuffle.remote.ShuffleServiceAddressProvider +import org.apache.spark.util.Utils + +class KubernetesShuffleServiceAddressProvider( + kubernetesClient: KubernetesClient, + pollForPodsExecutor: ScheduledExecutorService, + podLabels: Map[String, String], + namespace: String, + portNumber: Int) + extends ShuffleServiceAddressProvider with Logging { + + // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, + // but we don't need all "in-between" lists of all executor pods, just the latest known list + // when we query in getShuffleServiceAddresses. + + private val podsUpdateLock = new ReentrantReadWriteLock() + + private val shuffleServicePods = mutable.HashMap.empty[String, Pod] + + private var shuffleServicePodsWatch: Watch = _ + private var pollForPodsTask: ScheduledFuture[_] = _ + + override def start(): Unit = { + pollForPods() + pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( + () => pollForPods(), 0, 10, TimeUnit.SECONDS) + shuffleServicePodsWatch = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) + } + + override def stop(): Unit = { + Utils.tryLogNonFatalError { + if (pollForPodsTask != null) { + pollForPodsTask.cancel(false) + } + } + + Utils.tryLogNonFatalError { + if (shuffleServicePodsWatch != null) { + shuffleServicePodsWatch.close() + } + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() + } + } + + override def getShuffleServiceAddresses(): List[(String, Int)] = { + val readLock = podsUpdateLock.readLock() + readLock.lock() + try { + val addresses = shuffleServicePods.values.map(pod => { + (pod.getStatus.getPodIP, portNumber) + }).toList + logInfo(s"Found backup shuffle service addresses at $addresses.") + addresses + } finally { + readLock.unlock() + } + } + + private def pollForPods(): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + val allPods = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava) + .list() + shuffleServicePods.clear() + allPods.getItems.asScala.foreach(updatePod) + } finally { + writeLock.unlock() + } + } + + private def updatePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") + val state = SparkPodState.toState(pod) + state match { + case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => + shuffleServicePods.remove(pod.getMetadata.getName) + case PodRunning(_) => + shuffleServicePods.put(pod.getMetadata.getName, pod) + case _ => + logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") + } + } + + private def deletePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") + shuffleServicePods.remove(pod.getMetadata.getName) + } + + private class PutPodsInCacheWatcher extends Watcher[Pod] { + override def eventReceived(action: Watcher.Action, pod: Pod): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + updatePod(pod) + } finally { + writeLock.unlock() + } + } + + override def onClose(e: KubernetesClientException): Unit = {} + } + + private implicit def toRunnable(func: () => Unit): Runnable = { + new Runnable { + override def run(): Unit = func() + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala new file mode 100644 index 000000000000..e742aaba6c85 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle.k8s + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.shuffle.remote._ +import org.apache.spark.util.ThreadUtils + +class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { + override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") + + override def create(conf: SparkConf): ShuffleServiceAddressProvider = { + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + conf, conf.get("spark.master")) + val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( + "poll-shuffle-service-pods", 1) + val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS) + val shuffleServicePodsNamespace = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE) + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + + s" shuffle service must be defined by" + + s" ${KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") + require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") + + val port: Int = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT) + new KubernetesShuffleServiceAddressProvider( + kubernetesClient, + pollForPodsExecutor, + shuffleServiceLabels.toMap, + shuffleServicePodsNamespace.get, + port) + } +} From e3e9d6884ac43619d968591daf4f8ed940b3f811 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Thu, 20 Dec 2018 16:56:01 -0800 Subject: [PATCH 02/30] working version --- .../shuffle/ExternalShuffleBlockHandler.java | 23 ++ .../shuffle/ExternalShuffleBlockResolver.java | 21 +- .../shuffle/FileWriterStreamCallback.java | 150 +++++++++++++ .../k8s/KubernetesExternalShuffleClient.java | 160 ++++++++++++++ .../mesos/MesosExternalShuffleClient.java | 4 +- .../protocol/BlockTransferMessage.java | 8 +- .../protocol/OpenShufflePartition.java | 11 +- .../protocol/{mesos => }/RegisterDriver.java | 3 +- .../RegisterExecutorWithExternal.java | 90 ++++++++ .../{mesos => }/ShuffleServiceHeartbeat.java | 3 +- .../UploadShufflePartitionStream.java | 16 +- .../shuffle/api/ShufflePartitionWriter.java | 1 - .../external/ExternalShuffleDataIO.java | 12 +- ...va => ExternalShufflePartitionWriter.java} | 37 ++-- .../external/ExternalShuffleWriteSupport.java | 31 +-- .../ShuffleServiceAddressProvider.scala | 31 --- ...ShuffleServiceAddressProviderFactory.scala | 28 --- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/UnsafeShuffleWriter.java | 2 +- .../spark/internal/config/package.scala | 3 + .../apache/spark/storage/BlockManager.scala | 38 +++- .../sort/UnsafeShuffleWriterSuite.java | 3 +- .../org/apache/spark/deploy/k8s/Config.scala | 23 ++ .../KubernetesExternalShuffleService.scala | 202 ++++++++++++++++++ ...ernetesShuffleServiceAddressProvider.scala | 146 ------------- ...ShuffleServiceAddressProviderFactory.scala | 48 ----- .../mesos/MesosExternalShuffleService.scala | 3 +- 27 files changed, 768 insertions(+), 331 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/{mesos => }/RegisterDriver.java (94%) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/{mesos => }/ShuffleServiceHeartbeat.java (92%) rename core/src/main/java/org/apache/spark/shuffle/external/{ExternalShuffleMapOutputWriter.java => ExternalShufflePartitionWriter.java} (72%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala delete mode 100644 core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala delete mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 098fa7974b87..59e2229f2db1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -34,6 +34,7 @@ import org.slf4j.LoggerFactory; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; @@ -81,6 +82,22 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb handleMessage(msgObj, client, callback); } + @Override + public StreamCallbackWithID receiveStream( + TransportClient client, + ByteBuffer messageHeader, + RpcResponseCallback callback) { + BlockTransferMessage header = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader); + return handleStream(header, client, callback); + } + + protected StreamCallbackWithID handleStream( + BlockTransferMessage header, + TransportClient client, + RpcResponseCallback callback) { + throw new UnsupportedOperationException("Unexpected message header: " + header); + } + protected void handleMessage( BlockTransferMessage msgObj, TransportClient client, @@ -181,6 +198,10 @@ private class ShuffleMetrics implements MetricSet { private final Timer registerExecutorRequestLatencyMillis = new Timer(); // Block transfer rate in byte per second private final Meter blockTransferRateBytes = new Meter(); + // Partition upload latency in ms + private final Timer uploadPartitionkStreamMillis = new Timer(); + // Partition read latency in ms + private final Timer openPartitionMillis = new Timer(); private ShuffleMetrics() { allMetrics = new HashMap<>(); @@ -189,6 +210,8 @@ private ShuffleMetrics() { allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); allMetrics.put("registeredExecutorsSize", (Gauge) () -> blockManager.getRegisteredExecutorsSize()); + allMetrics.put("uploadPartitionkStreamMillis", uploadPartitionkStreamMillis); + allMetrics.put("openPartitionMillis", openPartitionMillis); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 0b7a27402369..6b1d879d0618 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -19,13 +19,14 @@ import java.io.*; import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.util.regex.Pattern; import java.util.*; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; import java.util.concurrent.Executors; import java.util.regex.Matcher; -import java.util.regex.Pattern; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; @@ -69,6 +70,9 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); + // TODO: Dont necessarily write to local + private final File shuffleDir; + private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); // Map containing all registered executors' metadata. @@ -92,8 +96,8 @@ public class ExternalShuffleBlockResolver { final DB db; private final List knownManagers = Arrays.asList( - "org.apache.spark.shuffle.sort.SortShuffleManager", - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); + "org.apache.spark.shuffle.sort.SortShuffleManager", + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { @@ -131,6 +135,10 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) { } else { executors = Maps.newConcurrentMap(); } + + // TODO: Remove local writes + this.shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile(); + this.directoryCleaner = directoryCleaner; } @@ -138,6 +146,7 @@ public int getRegisteredExecutorsSize() { return executors.size(); } + /** Registers a new Executor with all the configuration we need to find its shuffle files. */ public void registerExecutor( String appId, @@ -179,6 +188,8 @@ public ManagedBuffer getBlockData( return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } + + /** * Removes our metadata of all executors registered for the given application, and optionally * also deletes the local directories associated with the executors of that application in a @@ -302,8 +313,8 @@ private ManagedBuffer getSortBasedShuffleBlockData( * Hashes a filename into the corresponding local directory, in a manner consistent with * Spark's DiskBlockManager.getFile(). */ - @VisibleForTesting - static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + + public static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java new file mode 100644 index 000000000000..fb753fed2d18 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -0,0 +1,150 @@ +package org.apache.spark.network.shuffle; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.network.client.StreamCallbackWithID; + +public class FileWriterStreamCallback implements StreamCallbackWithID { + + private static final Logger logger = LoggerFactory.getLogger(FileWriterStreamCallback.class); + + public enum FileType { + DATA("shuffle-data"), + INDEX("shuffle-index"); + + private final String typeString; + + FileType(String typeString) { + this.typeString = typeString; + } + + @Override + public String toString() { + return typeString; + } + } + + private final ExternalShuffleBlockResolver.AppExecId fullExecId; + private final int shuffleId; + private final int mapId; + private final File file; + private final FileType fileType; + private WritableByteChannel fileOutputChannel = null; + + public FileWriterStreamCallback( + ExternalShuffleBlockResolver.AppExecId fullExecId, + int shuffleId, + int mapId, + File file, + FileWriterStreamCallback.FileType fileType) { + this.fullExecId = fullExecId; + this.shuffleId = shuffleId; + this.mapId = mapId; + this.file = file; + this.fileType = fileType; + } + + public void open() { + logger.info( + "Opening {} for backup writing. File type: {}", file.getAbsolutePath(), fileType); + if (fileOutputChannel != null) { + throw new IllegalStateException( + String.format( + "File %s for is already open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + if (!file.exists()) { + try { + if (!file.getParentFile().isDirectory() && !file.getParentFile().mkdirs()) { + throw new IOException( + String.format( + "Failed to create shuffle file directory at" + + file.getParentFile().getAbsolutePath() + "(type: %s).", fileType)); + } + + if (!file.createNewFile()) { + throw new IOException( + String.format( + "Failed to create shuffle file (type: %s).", fileType)); + } + } catch (IOException e) { + throw new RuntimeException( + String.format( + "Failed to create shuffle file at %s for backup (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + try { + // TODO encryption + fileOutputChannel = Channels.newChannel(new FileOutputStream(file)); + } catch (FileNotFoundException e) { + throw new RuntimeException( + String.format( + "Failed to find file for writing at %s (type: %s).", + file.getAbsolutePath(), + fileType), + e); + } + } + + @Override + public String getID() { + return String.format("%s-%s-%d-%d-%s", + fullExecId.appId, + fullExecId.execId, + shuffleId, + mapId, + fileType); + } + + @Override + public void onData(String streamId, ByteBuffer buf) throws IOException { + verifyShuffleFileOpenForWriting(); + while (buf.hasRemaining()) { + fileOutputChannel.write(buf); + } + } + + @Override + public void onComplete(String streamId) throws IOException { + fileOutputChannel.close(); + } + + @Override + public void onFailure(String streamId, Throwable cause) throws IOException { + logger.warn("Failed to back up shuffle file at {} (type: %s).", + file.getAbsolutePath(), + fileType, + cause); + fileOutputChannel.close(); + // TODO delete parent dirs too + if (!file.delete()) { + logger.warn( + "Failed to delete incomplete backup shuffle file at %s (type: %s)", + file.getAbsolutePath(), + fileType); + } + } + + private void verifyShuffleFileOpenForWriting() { + if (fileOutputChannel == null) { + throw new RuntimeException( + String.format( + "Shuffle file at %s not open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); + } + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java new file mode 100644 index 000000000000..20a72851304b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.k8s; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +import com.google.common.util.concurrent.ThreadFactoryBuilder; +import org.apache.spark.network.shuffle.protocol.RegisterExecutorWithExternal; +import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SecretKeyHolder; +import org.apache.spark.network.shuffle.ExternalShuffleClient; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; +import org.apache.spark.network.util.TransportConf; + +/** + * A client for talking to the external shuffle service in Kubernetes coarse-grained mode. + * + * This is used by the Spark driver to register with each external shuffle service on the cluster. + * The reason why the driver has to talk to the service is for cleaning up shuffle files reliably + * after the application exits. Kubernetes does not provide a great alternative to do this, so Spark + * has to detect this itself. + */ +public class KubernetesExternalShuffleClient extends ExternalShuffleClient { + private static final Logger logger = + LoggerFactory.getLogger(KubernetesExternalShuffleClient.class); + + private final ScheduledExecutorService heartbeaterThread = + Executors.newSingleThreadScheduledExecutor( + new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("kubernetes-external-shuffle-client-heartbeater") + .build()); + + /** + * Creates a Kubernetes external shuffle client that wraps the {@link ExternalShuffleClient}. + * Please refer to docs on {@link ExternalShuffleClient} for more information. + */ + public KubernetesExternalShuffleClient( + TransportConf conf, + SecretKeyHolder secretKeyHolder, + boolean authEnabled, + long registrationTimeoutMs) { + super(conf, secretKeyHolder, authEnabled, registrationTimeoutMs); + } + + public void registerDriverWithShuffleService( + String host, + int port, + long heartbeatTimeoutMs, + long heartbeatIntervalMs) throws IOException, InterruptedException { + + checkInit(); + ByteBuffer registerDriver = new RegisterDriver(appId, heartbeatTimeoutMs).toByteBuffer(); + logger.info("Registering with external shuffle service at " + host + ":" + port); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); + } + + public void registerExecutorWithShuffleService( + String host, + int port, + String appId, + String execId, + String shuffleManager) throws IOException, InterruptedException { + checkInit(); + ByteBuffer registerExecutor = + new RegisterExecutorWithExternal(appId, execId, shuffleManager).toByteBuffer(); + logger.info("Registering with external shuffle service for " + appId + ":" + execId); + TransportClient client = clientFactory.createClient(host, port); + client.sendRpc(registerExecutor, new RegisterExecutorCallback(appId, execId)); + } + + private class RegisterDriverCallback implements RpcResponseCallback { + private final TransportClient client; + private final long heartbeatIntervalMs; + + private RegisterDriverCallback(TransportClient client, long heartbeatIntervalMs) { + this.client = client; + this.heartbeatIntervalMs = heartbeatIntervalMs; + } + + @Override + public void onSuccess(ByteBuffer response) { + heartbeaterThread.scheduleAtFixedRate( + new Heartbeater(client), 0, heartbeatIntervalMs, TimeUnit.MILLISECONDS); + logger.info("Successfully registered app " + appId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register app " + appId + " with external shuffle service. " + + "Please manually remove shuffle data after driver exit. Error: " + e); + } + } + + private class RegisterExecutorCallback implements RpcResponseCallback { + private String appId; + private String execId; + + private RegisterExecutorCallback(String appId, String execId) { + this.appId = appId; + this.execId = execId; + } + + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully registered " + appId + ":" + execId + " with external shuffle service."); + } + + @Override + public void onFailure(Throwable e) { + logger.warn("Unable to register " + appId + ":" + execId + " with external shuffle service, " + e); + } + } + + @Override + public void close() { + heartbeaterThread.shutdownNow(); + super.close(); + } + + private class Heartbeater implements Runnable { + + private final TransportClient client; + + private Heartbeater(TransportClient client) { + this.client = client; + } + + @Override + public void run() { + // TODO: Stop sending heartbeats if the shuffle service has lost the app due to timeout + client.send(new ShuffleServiceHeartbeat(appId).toByteBuffer()); + } + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java index 60179f126bc4..3510509f20ee 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/mesos/MesosExternalShuffleClient.java @@ -24,7 +24,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; +import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +32,7 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.shuffle.ExternalShuffleClient; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; +import org.apache.spark.network.shuffle.protocol.RegisterDriver; import org.apache.spark.network.util.TransportConf; /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a68a297519b6..8185193061a7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -23,8 +23,6 @@ import io.netty.buffer.Unpooled; import org.apache.spark.network.protocol.Encodable; -import org.apache.spark.network.shuffle.protocol.mesos.RegisterDriver; -import org.apache.spark.network.shuffle.protocol.mesos.ShuffleServiceHeartbeat; /** * Messages handled by the {@link org.apache.spark.network.shuffle.ExternalShuffleBlockHandler}, or @@ -42,7 +40,8 @@ public abstract class BlockTransferMessage implements Encodable { /** Preceding every serialized message is its type, which allows us to deserialize it. */ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), - HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6); + HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), + OPEN_SHUFFLE_PARTITION(8), REGISTER_EXECUTOR_WITH_EXTERNAL(9); private final byte id; @@ -68,6 +67,9 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 4: return RegisterDriver.decode(buf); case 5: return ShuffleServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); + case 7: return UploadShufflePartitionStream.decode(buf); + case 8: return OpenShufflePartition.decode(buf); + case 9: return RegisterExecutorWithExternal.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java index 02b09a83cfb7..408be7fad26d 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java @@ -4,6 +4,9 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + public class OpenShufflePartition extends BlockTransferMessage { public final String appId; public final String execId; @@ -11,7 +14,8 @@ public class OpenShufflePartition extends BlockTransferMessage { public final int mapId; public final int partitionId; - public OpenShufflePartition(String appId, String execId, int shuffleId, int mapId, int partitionId) { + public OpenShufflePartition( + String appId, String execId, int shuffleId, int mapId, int partitionId) { this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; @@ -34,7 +38,7 @@ public boolean equals(Object other) { @Override protected Type type() { - return null; + return Type.OPEN_SHUFFLE_PARTITION; } @Override @@ -54,7 +58,8 @@ public String toString() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java similarity index 94% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java index d5f53ccb7f74..516a51ad7cc1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/RegisterDriver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterDriver.java @@ -15,13 +15,12 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import com.google.common.base.Objects; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; // Needed by ScalaDoc. See SPARK-7726 import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java new file mode 100644 index 000000000000..64a2fd77333b --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +public class RegisterExecutorWithExternal extends BlockTransferMessage { + + public final String appId; + public final String execId; + public final String shuffleManager; + + public RegisterExecutorWithExternal( + String appId, String execId, String shuffleManager) { + this.appId = appId; + this.execId = execId; + this.shuffleManager = shuffleManager; + } + + @Override + protected Type type() { + return Type.REGISTER_EXECUTOR_WITH_EXTERNAL; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + Encoders.Strings.encodedLength(shuffleManager); + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + Encoders.Strings.encode(buf, shuffleManager); + } + + @Override + public boolean equals(Object other) { + if (other instanceof RegisterExecutorWithExternal) { + RegisterExecutorWithExternal o = (RegisterExecutorWithExternal) other; + return Objects.equal(appId, o.appId) + && Objects.equal(execId, o.execId) + && Objects.equal(shuffleManager, o.shuffleManager); + } + return false; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, shuffleManager); + } + + @Override + public String toString() { + return Objects.toStringHelper(RegisterExecutorWithExternal.class) + .add("appId", appId) + .add("execId", execId) + .add("shuffleManager", shuffleManager) + .toString(); + } + + public static RegisterExecutorWithExternal decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + String shuffleManager = Encoders.Strings.decode(buf); + return new RegisterExecutorWithExternal(appId, execId, shuffleManager); + } +} \ No newline at end of file diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java similarity index 92% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java index b30bb9aed55b..1a6ffc0f9133 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/mesos/ShuffleServiceHeartbeat.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ShuffleServiceHeartbeat.java @@ -15,11 +15,10 @@ * limitations under the License. */ -package org.apache.spark.network.shuffle.protocol.mesos; +package org.apache.spark.network.shuffle.protocol; import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; // Needed by ScalaDoc. See SPARK-7726 import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java index 0a6c78e27c1a..a72d81b84339 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -21,13 +21,15 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + /** * Upload shuffle partition request to the External Shuffle Service. * This request should also include the driverHostPort for the sake of * setting up a driver heartbeat to monitor heartbeat */ public class UploadShufflePartitionStream extends BlockTransferMessage { - public final String driverHostPort; public final String appId; public final String execId; public final int shuffleId; @@ -35,13 +37,11 @@ public class UploadShufflePartitionStream extends BlockTransferMessage { public final int partitionId; public UploadShufflePartitionStream( - String driverHostPort, String appId, String execId, int shuffleId, int mapId, int partitionId) { - this.driverHostPort = driverHostPort; this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; @@ -54,7 +54,6 @@ public boolean equals(Object other) { if (other != null && other instanceof UploadShufflePartitionStream) { UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; return Objects.equal(appId, o.appId) - && driverHostPort == o.driverHostPort && execId == o.execId && shuffleId == o.shuffleId && mapId == o.mapId @@ -70,13 +69,12 @@ protected Type type() { @Override public int hashCode() { - return Objects.hashCode(driverHostPort, appId, execId, shuffleId, mapId, partitionId); + return Objects.hashCode(appId, execId, shuffleId, mapId, partitionId); } @Override public String toString() { return Objects.toStringHelper(this) - .add("driverHostPort", driverHostPort) .add("appId", appId) .add("execId", execId) .add("shuffleId", shuffleId) @@ -86,13 +84,12 @@ public String toString() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(driverHostPort) + Encoders.Strings.encodedLength(appId) + + return Encoders.Strings.encodedLength(appId) + Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; } @Override public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, driverHostPort); Encoders.Strings.encode(buf, appId); Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); @@ -101,12 +98,11 @@ public void encode(ByteBuf buf) { } public static UploadShufflePartitionStream decode(ByteBuf buf) { - String driverHostPort = Encoders.Strings.decode(buf); String appId = Encoders.Strings.decode(buf); String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); int mapId = buf.readInt(); int partitionId = buf.readInt(); - return new UploadShufflePartitionStream(driverHostPort, appId, execId, shuffleId, mapId, partitionId); + return new UploadShufflePartitionStream(appId, execId, shuffleId, mapId, partitionId); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 7fa667cf137e..4135458ddd43 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -17,7 +17,6 @@ package org.apache.spark.shuffle.api; -import java.io.InputStream; import java.io.OutputStream; /** diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 3ce83201665a..2b9dbb6b226a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -15,19 +15,24 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private static final String SHUFFLE_SERVICE_PORT_CONFIG = "spark.shuffle.service.port"; private static final String DEFAULT_SHUFFLE_PORT = "7337"; + private static final SparkEnv sparkEnv = SparkEnv.get(); + private final SparkConf sparkConf; private final TransportConf conf; private final SecurityManager securityManager; private final String hostname; private final int port; private final String execId; + private final String driverHostName; public ExternalShuffleDataIO( SparkConf sparkConf) { this.sparkConf = sparkConf; this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 2); - this.securityManager = SparkEnv.get().securityManager(); - this.hostname = SparkEnv.get().blockManager().blockTransferService().hostName(); + + this.securityManager = sparkEnv.securityManager(); + this.hostname = sparkEnv.blockManager().blockTransferService().hostName(); + this.driverHostName = sparkEnv.blockManager().master().driverEndpoint().address().hostPort(); int tmpPort = Integer.parseInt( Utils.getSparkOrYarnConfig(sparkConf, SHUFFLE_SERVICE_PORT_CONFIG, DEFAULT_SHUFFLE_PORT)); @@ -53,6 +58,7 @@ public ShuffleReadSupport readSupport() { @Override public ShuffleWriteSupport writeSupport() { return new ExternalShuffleWriteSupport( - conf, securityManager.isAuthenticationEnabled(), securityManager, hostname, port, execId); + conf, securityManager.isAuthenticationEnabled(), + securityManager, hostname, port, execId, driverHostName); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java similarity index 72% rename from core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java rename to core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index c75b0b9dbc56..7c010dacc38c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -1,5 +1,6 @@ package org.apache.spark.shuffle.external; +import org.apache.hadoop.hive.serde2.ByteStream; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; @@ -9,8 +10,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; import java.nio.ByteBuffer; public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { @@ -23,8 +23,10 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private final int shuffleId; private final int mapId; private final int partitionId; + private final String driverHostPort; private long totalLength = 0; + private final ByteArrayOutputStream partitionBuffer = new ByteArrayOutputStream(); public ExternalShufflePartitionWriter( TransportClient client, @@ -32,17 +34,24 @@ public ExternalShufflePartitionWriter( String execId, int shuffleId, int mapId, - int partitionId) { + int partitionId, + String driverHostPort) { this.client = client; this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; + this.driverHostPort = driverHostPort; } @Override - public void appendBytesToPartition(InputStream streamReadingBytesToAppend) { + public OutputStream openPartitionStream() { + return partitionBuffer; + } + + @Override + public long commitAndGetTotalLength() { RpcResponseCallback callback = new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { @@ -56,27 +65,27 @@ public void onFailure(Throwable e) { }; try { ByteBuffer streamHeader = - new UploadShufflePartitionStream(this.appId, execId, shuffleId, mapId, partitionId).toByteBuffer(); - int avaibleSize = streamReadingBytesToAppend.available(); - byte[] buf = new byte[avaibleSize]; - int size = streamReadingBytesToAppend.read(buf, 0, avaibleSize); - assert size == avaibleSize; + new UploadShufflePartitionStream( + this.appId, execId, shuffleId, mapId, partitionId, driverHostPort).toByteBuffer(); + int size = partitionBuffer.size(); + byte[] buf = partitionBuffer.toByteArray(); + ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); totalLength += size; } catch (Exception e) { logger.error("Encountered error while attempting to upload partition to ESS", e); + client.close(); throw new RuntimeException(e); + } finally { + logger.info("Successfully sent partition to ESS"); + client.close(); } - } - - @Override - public long commitAndGetTotalLength() { return totalLength; } @Override public void abort(Exception failureReason) { - // TODO + logger.error("Encountered error while attempting to upload partition to ESS", failureReason); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 102268371cd0..388628e9d24d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -27,32 +27,18 @@ public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { private final String hostname; private final int port; private final String execId; + private final String driverHostPort; public ExternalShuffleWriteSupport( - TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, String hostname, int port, String execId) { + TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, + String hostname, int port, String execId, String driverHostPort) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; this.hostname = hostname; this.port = port; this.execId = execId; - } - - @Override - public ShufflePartitionWriter newPartitionWriter(String appId, int shuffleId, int mapId, int partitionId) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); - List bootstraps = Lists.newArrayList(); - if (authEnabled) { - bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); - } - TransportClientFactory clientFactory = context.createClientFactory(bootstraps); - try { - TransportClient client = clientFactory.createClient(hostname, port); - return new ExternalShufflePartitionWriter(client, appId, execId, shuffleId, mapId, partitionId); - } catch (Exception e) { - logger.error("Encountered error while creating transport client"); - throw new RuntimeException(e); // what is standard practice here? - } + this.driverHostPort = driverHostPort; } @Override @@ -68,7 +54,8 @@ public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, in public ShufflePartitionWriter newPartitionWriter(int partitionId) { try { TransportClient client = clientFactory.createClient(hostname, port); - return new ExternalShufflePartitionWriter(client, appId, execId, shuffleId, mapId, partitionId); + return new ExternalShufflePartitionWriter( + client, appId, execId, shuffleId, mapId, partitionId, driverHostPort); } catch (Exception e) { logger.error("Encountered error while creating transport client"); throw new RuntimeException(e); // what is standard practice here? @@ -77,13 +64,13 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) { @Override public void commitAllPartitions() { - + logger.info("Commiting all partitions"); } @Override public void abort(Exception exception) { - + logger.error("Encountered error while attempting to all partitions to ESS", exception); } - } + }; } } diff --git a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala deleted file mode 100644 index 0a643a678579..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProvider.scala +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.remote - -trait ShuffleServiceAddressProvider { - - def start(): Unit = {} - - def getShuffleServiceAddresses(): List[(String, Int)] - - def stop(): Unit = {} -} - -private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { - override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] -} \ No newline at end of file diff --git a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala b/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala deleted file mode 100644 index eaa30110093c..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/remote/ShuffleServiceAddressProviderFactory.scala +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.remote - -import org.apache.spark.SparkConf - -trait ShuffleServiceAddressProviderFactory { - - def canCreate(masterUrl: String): Boolean - - def create(conf: SparkConf): ShuffleServiceAddressProvider - -} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index c683b2854b17..823c36d051dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -239,7 +239,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio final long writeStartTime = System.nanoTime(); ShuffleMapOutputWriter mapOutputWriter = pluggableWriteSupport.newMapOutputWriter( - appId, shuffleId, mapId); + appId, shuffleId, mapId); try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 150a783aa87f..7336a75af123 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -85,7 +85,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { private final int initialSortBufferSize; private final int inputBufferSizeInBytes; private final int outputBufferSizeInBytes; - private final ShuffleWriteSupport pluggableWriteSupport; // TODO initialize + private final ShuffleWriteSupport pluggableWriteSupport; @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index bede012e3397..533672f809f4 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -182,6 +182,9 @@ package object config { private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val K8S_SHUFFLE_SERVICE_ENABLED = + ConfigBuilder("spark.k8s.shuffle.service.enabled").booleanConf.createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PORT = ConfigBuilder("spark.shuffle.service.port").intConf.createWithDefault(7337) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1dfbc6effb34..9db290f9ba2c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -44,6 +44,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.k8s.KubernetesExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv @@ -130,15 +131,22 @@ private[spark] class BlockManager( numUsableCores: Int) extends BlockDataManager with BlockEvictionHandler with Logging { - private[spark] val externalShuffleServiceEnabled = + private[spark] val externalNonK8sShuffleService = conf.get(config.SHUFFLE_SERVICE_ENABLED) + + private[spark] val externalk8sShuffleServiceEnabled = + conf.get(config.K8S_SHUFFLE_SERVICE_ENABLED) + + private[spark] val externalShuffleServiceEnabled = + externalNonK8sShuffleService || externalk8sShuffleServiceEnabled + private val remoteReadNioBufferConversion = conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) val diskBlockManager = { // Only perform cleanup if an external service is not serving our shuffle files. - val deleteFilesOnStop = - !externalShuffleServiceEnabled || executorId == SparkContext.DRIVER_IDENTIFIER + val deleteFilesOnStop = !externalShuffleServiceEnabled || + executorId == SparkContext.DRIVER_IDENTIFIER new DiskBlockManager(conf, deleteFilesOnStop) } @@ -184,7 +192,11 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. - private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { + private[spark] val shuffleClient = if (externalk8sShuffleServiceEnabled) { + val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) + new KubernetesExternalShuffleClient(transConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) + } else if (externalNonK8sShuffleService) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) @@ -252,6 +264,7 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id + // TODO: Customize so that the shuffleServiceID is pointing to K8s shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) @@ -259,8 +272,21 @@ private[spark] class BlockManager( blockManagerId } - // Register Executors' configuration with the local shuffle service, if one should exist. - if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { + if (externalk8sShuffleServiceEnabled && blockManagerId.isDriver) { + // Register Drivers' configuration with the k8s shuffle service + shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] + .registerDriverWithShuffleService( + shuffleServerId.host, shuffleServerId.port, + conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), + conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) + } else if (externalk8sShuffleServiceEnabled && !blockManagerId.isDriver) { + shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] + .registerExecutorWithShuffleService( + shuffleServerId.host, shuffleServerId.port, appId, + shuffleServerId.executorId, shuffleManager.getClass.getName) + } else if (externalNonK8sShuffleService && !blockManagerId.isDriver) { + // Register Executors' configuration with the local shuffle service, if one should exist. registerWithExternalShuffleServer() } diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 360c1769ad31..18d8e09589c8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -644,7 +644,8 @@ public void testPeakMemoryUsed() throws Exception { private final class TestShuffleWriteSupport implements ShuffleWriteSupport { @Override - public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { + public ShuffleMapOutputWriter newMapOutputWriter( + String appId, int shuffleId, int mapId) { try { if (!mergedOutputFile.exists() && !mergedOutputFile.createNewFile()) { throw new IllegalStateException( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index e8bf16df190e..b2c67ba7f920 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -289,6 +289,26 @@ private[spark] object Config extends Logging { .booleanConf .createWithDefault(true) + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE = + ConfigBuilder("spark.kubernetes.shuffle.service.remote.pods.namespace") + .doc("Namespace of the pods that are running the shuffle service instances for backing up" + + " shuffle data.") + .stringConf + .createOptional + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT = + ConfigBuilder("spark.kubernetes.shuffle.service.remote.port") + .doc("Port of the shuffle services that will back up the application's shuffle data.") + .intConf + .createWithDefault(7337) + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL = + ConfigBuilder("spark.kubernetes.shuffle.service.cleanup.interval") + .doc("Cleanup interval for the shuffle service to take down an app id") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("30s") + + val KUBERNETES_DRIVER_LABEL_PREFIX = "spark.kubernetes.driver.label." val KUBERNETES_DRIVER_ANNOTATION_PREFIX = "spark.kubernetes.driver.annotation." val KUBERNETES_DRIVER_SECRETS_PREFIX = "spark.kubernetes.driver.secrets." @@ -313,4 +333,7 @@ private[spark] object Config extends Logging { val KUBERNETES_VOLUMES_OPTIONS_SIZE_LIMIT_KEY = "options.sizeLimit" val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." + + val KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS = + "spark.kubernetes.shuffle.service.remote.label." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala new file mode 100644 index 000000000000..6efe73cb659d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s + +import java.nio.ByteBuffer +import java.nio.file.Files +import java.nio.file.Paths +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} + +import scala.collection.JavaConverters._ + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.ExternalShuffleService +import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL +import org.apache.spark.internal.Logging +import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} +import org.apache.spark.network.shuffle._ +import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver._ +import org.apache.spark.network.shuffle.protocol._ +import org.apache.spark.network.util.{JavaUtils, TransportConf} +import org.apache.spark.util.ThreadUtils + +/** + * An RPC endpoint that receives registration requests from Spark drivers running on Kubernetes. + * It detects driver termination and calls the cleanup callback to [[ExternalShuffleService]]. + */ +private[spark] class KubernetesExternalShuffleBlockHandler( + transportConf: TransportConf, + cleanerIntervalS: Long) + extends ExternalShuffleBlockHandler(transportConf, null) with Logging { + + ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") + .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS) + + // Stores a map of app id to app state (timeout value and last heartbeat) + private val connectedApps = new ConcurrentHashMap[String, AppState]() + private val registeredExecutors = new ConcurrentHashMap[AppExecId, ExecutorShuffleInfo]() + private val knownManagers = Array( + "org.apache.spark.shuffle.sort.SortShuffleManager", + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + private final val shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile() + + protected override def handleMessage( + message: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): Unit = { + message match { + case RegisterExecutorParam(appId, execId, shuffleManager) => + val fullId = new AppExecId(appId, execId) + if (registeredExecutors.containsKey(fullId)) { + throw new UnsupportedOperationException(s"Executor $fullId cannot be registered twice") + } + val executorDir = Paths.get(shuffleDir.getAbsolutePath, appId, execId).toFile + if (!executorDir.mkdir()) { + throw new RuntimeException(s"Failed to create dir ${executorDir.getAbsolutePath}") + } + if (!knownManagers.contains(shuffleManager)) { + throw new UnsupportedOperationException(s"Unsupported shuffle manager of exec: ${fullId}") + } + val executorShuffleInfo = new ExecutorShuffleInfo( + Array(executorDir.getAbsolutePath), 1, shuffleManager) + logInfo(s"Registering executor ${fullId} with ${executorShuffleInfo}") + registeredExecutors.put(fullId, executorShuffleInfo) + + case RegisterDriverParam(appId, appState) => + val address = client.getSocketAddress + val timeout = appState.heartbeatTimeout + logInfo(s"Received registration request from app $appId (remote address $address, " + + s"heartbeat timeout $timeout ms).") + if (connectedApps.containsKey(appId)) { + logWarning(s"Received a registration request from app $appId, but it was already " + + s"registered") + } + connectedApps.put(appId, appState) + callback.onSuccess(ByteBuffer.allocate(0)) + + case Heartbeat(appId) => + val address = client.getSocketAddress + Option(connectedApps.get(appId)) match { + case Some(existingAppState) => + logTrace(s"Received ShuffleServiceHeartbeat from app '$appId' (remote " + + s"address $address).") + existingAppState.lastHeartbeat = System.nanoTime() + case None => + logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + + s"address $address, appId '$appId').") + } + case _ => super.handleMessage(message, client, callback) + } + } + + protected override def handleStream( + header: BlockTransferMessage, + client: TransportClient, + callback: RpcResponseCallback): StreamCallbackWithID = { + header match { + case UploadParam( + appId, execId, shuffleId, mapId, partitionId) => + getFileWriterStreamCallback( + appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + case _ => super.handleStream(header, client, callback) + } + } + + private def getFileWriterStreamCallback( + appId: String, + execId: String, + shuffleId: Int, + mapId: Int, + extension: String, + fileType: FileWriterStreamCallback.FileType): StreamCallbackWithID = { + val fullId = new AppExecId(appId, execId) + val executor = registeredExecutors.get(fullId) + if (executor == null) { + throw new RuntimeException( + s"Executor is not registered for remote shuffle (appId=$appId, execId=$execId)") + } + val backedUpFile = + ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0." + extension) + val streamCallback = + new FileWriterStreamCallback(fullId, shuffleId, mapId, backedUpFile, fileType) + streamCallback.open() + streamCallback + } + + /** An extractor object for matching BlockTransferMessages. */ + private object RegisterDriverParam { + def unapply(r: RegisterDriver): Option[(String, AppState)] = + Some((r.getAppId, new AppState(r.getHeartbeatTimeoutMs, System.nanoTime()))) + } + + private object Heartbeat { + def unapply(h: ShuffleServiceHeartbeat): Option[String] = Some(h.getAppId) + } + + private object UploadParam { + def unapply(u: UploadShufflePartitionStream): Option[(String, String, Int, Int, Int)] = + Some((u.appId, u.execId, u.shuffleId, u.mapId, u.partitionId)) + } + + private object RegisterExecutorParam { + def unapply(e: RegisterExecutorWithExternal): Option[(String, String, String)] = + Some((e.appId, e.execId, e.shuffleManager)) + } + + private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) + + private class CleanerThread extends Runnable { + override def run(): Unit = { + val now = System.nanoTime() + connectedApps.asScala.foreach { case (appId, appState) => + if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { + logInfo(s"Application $appId timed out. Removing shuffle files.") + connectedApps.remove(appId) + applicationRemoved(appId, true) + } + } + } + } +} + +/** + * A wrapper of [[ExternalShuffleService]] that provides an additional endpoint for drivers + * to associate with. This allows the shuffle service to detect when a driver is terminated + * and can clean up the associated shuffle files. + */ +private[spark] class KubernetesExternalShuffleService( + conf: SparkConf, securityManager: SecurityManager) + extends ExternalShuffleService(conf, securityManager) { + + protected override def newShuffleBlockHandler( + conf: TransportConf): ExternalShuffleBlockHandler = { + val cleanerIntervalS = this.conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) + new KubernetesExternalShuffleBlockHandler(conf, cleanerIntervalS) + } +} + +private[spark] object KubernetesExternalShuffleService extends Logging { + + def main(args: Array[String]): Unit = { + ExternalShuffleService.main(args, + (conf: SparkConf, sm: SecurityManager) => new KubernetesExternalShuffleService(conf, sm)) + } +} + + diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala deleted file mode 100644 index 39f6a9a3a213..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.k8s - -import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} -import java.util.concurrent.locks.ReentrantReadWriteLock - -import io.fabric8.kubernetes.api.model.Pod -import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} -import scala.collection.JavaConverters._ -import scala.collection.mutable - -import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.cluster.k8s._ -import org.apache.spark.shuffle.remote.ShuffleServiceAddressProvider -import org.apache.spark.util.Utils - -class KubernetesShuffleServiceAddressProvider( - kubernetesClient: KubernetesClient, - pollForPodsExecutor: ScheduledExecutorService, - podLabels: Map[String, String], - namespace: String, - portNumber: Int) - extends ShuffleServiceAddressProvider with Logging { - - // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, - // but we don't need all "in-between" lists of all executor pods, just the latest known list - // when we query in getShuffleServiceAddresses. - - private val podsUpdateLock = new ReentrantReadWriteLock() - - private val shuffleServicePods = mutable.HashMap.empty[String, Pod] - - private var shuffleServicePodsWatch: Watch = _ - private var pollForPodsTask: ScheduledFuture[_] = _ - - override def start(): Unit = { - pollForPods() - pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( - () => pollForPods(), 0, 10, TimeUnit.SECONDS) - shuffleServicePodsWatch = kubernetesClient - .pods() - .inNamespace(namespace) - .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) - } - - override def stop(): Unit = { - Utils.tryLogNonFatalError { - if (pollForPodsTask != null) { - pollForPodsTask.cancel(false) - } - } - - Utils.tryLogNonFatalError { - if (shuffleServicePodsWatch != null) { - shuffleServicePodsWatch.close() - } - } - - Utils.tryLogNonFatalError { - kubernetesClient.close() - } - } - - override def getShuffleServiceAddresses(): List[(String, Int)] = { - val readLock = podsUpdateLock.readLock() - readLock.lock() - try { - val addresses = shuffleServicePods.values.map(pod => { - (pod.getStatus.getPodIP, portNumber) - }).toList - logInfo(s"Found backup shuffle service addresses at $addresses.") - addresses - } finally { - readLock.unlock() - } - } - - private def pollForPods(): Unit = { - val writeLock = podsUpdateLock.writeLock() - writeLock.lock() - try { - val allPods = kubernetesClient - .pods() - .inNamespace(namespace) - .withLabels(podLabels.asJava) - .list() - shuffleServicePods.clear() - allPods.getItems.asScala.foreach(updatePod) - } finally { - writeLock.unlock() - } - } - - private def updatePod(pod: Pod): Unit = { - require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") - val state = SparkPodState.toState(pod) - state match { - case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => - shuffleServicePods.remove(pod.getMetadata.getName) - case PodRunning(_) => - shuffleServicePods.put(pod.getMetadata.getName, pod) - case _ => - logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") - } - } - - private def deletePod(pod: Pod): Unit = { - require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") - shuffleServicePods.remove(pod.getMetadata.getName) - } - - private class PutPodsInCacheWatcher extends Watcher[Pod] { - override def eventReceived(action: Watcher.Action, pod: Pod): Unit = { - val writeLock = podsUpdateLock.writeLock() - writeLock.lock() - try { - updatePod(pod) - } finally { - writeLock.unlock() - } - } - - override def onClose(e: KubernetesClientException): Unit = {} - } - - private implicit def toRunnable(func: () => Unit): Runnable = { - new Runnable { - override def run(): Unit = func() - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala deleted file mode 100644 index e742aaba6c85..000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.shuffle.k8s - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.shuffle.remote._ -import org.apache.spark.util.ThreadUtils - -class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { - override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") - - override def create(conf: SparkConf): ShuffleServiceAddressProvider = { - val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( - conf, conf.get("spark.master")) - val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( - "poll-shuffle-service-pods", 1) - val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS) - val shuffleServicePodsNamespace = conf.get(KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE) - require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + - s" shuffle service must be defined by" + - s" ${KUBERNETES_BACKUP_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") - require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") - - val port: Int = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT) - new KubernetesShuffleServiceAddressProvider( - kubernetesClient, - pollForPodsExecutor, - shuffleServiceLabels.toMap, - shuffleServicePodsNamespace.get, - port) - } -} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 859aa836a315..6d94b9efd1d2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -28,8 +28,7 @@ import org.apache.spark.deploy.mesos.config._ import org.apache.spark.internal.Logging import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage -import org.apache.spark.network.shuffle.protocol.mesos.{RegisterDriver, ShuffleServiceHeartbeat} +import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterDriver, ShuffleServiceHeartbeat} import org.apache.spark.network.util.TransportConf import org.apache.spark.util.ThreadUtils From 458b2beeae412d31e61403031c2f7dea4770af5f Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 26 Dec 2018 12:52:39 -0500 Subject: [PATCH 03/30] added shuffle location discovery --- .../shuffle/ExternalShuffleClient.java | 3 +- .../shuffle/FileWriterStreamCallback.java | 2 +- .../k8s/KubernetesExternalShuffleClient.java | 6 +- .../RegisterExecutorWithExternal.java | 2 +- .../shuffle/api/ShufflePartitionWriter.java | 9 +- .../external/ExternalShuffleDataIO.java | 12 +- .../ExternalShufflePartitionReader.java | 30 ++-- .../ExternalShufflePartitionWriter.java | 17 +- .../external/ExternalShuffleWriteSupport.java | 9 +- .../shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../org/apache/spark/MapOutputTracker.scala | 14 +- .../scala/org/apache/spark/SparkEnv.scala | 24 ++- .../ShuffleServiceAddressProvider.scala | 25 +-- ...ShuffleServiceAddressProviderFactory.scala | 25 +++ .../apache/spark/storage/BlockManager.scala | 16 +- .../sort/UnsafeShuffleWriterSuite.java | 7 +- .../KubernetesExternalShuffleService.scala | 2 +- .../k8s/SparkKubernetesClientFactory.scala | 31 ++++ .../cluster/k8s/ExecutorPodsSnapshot.scala | 8 +- .../k8s/KubernetesClusterManager.scala | 28 +--- .../cluster/k8s/SparkPodStates.scala | 67 ++++++++ ...ernetesShuffleServiceAddressProvider.scala | 148 ++++++++++++++++++ ...ShuffleServiceAddressProviderFactory.scala | 52 ++++++ 23 files changed, 439 insertions(+), 101 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala => core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala (60%) create mode 100644 core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala create mode 100644 resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index e49e27ab5aa7..2a013f5497a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -140,7 +140,8 @@ public void registerWithShuffleServer( ExecutorShuffleInfo executorInfo) throws IOException, InterruptedException { checkInit(); try (TransportClient client = clientFactory.createUnmanagedClient(host, port)) { - ByteBuffer registerMessage = new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); + ByteBuffer registerMessage = + new RegisterExecutor(appId, execId, executorInfo).toByteBuffer(); client.sendRpcSync(registerMessage, registrationTimeoutMs); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java index fb753fed2d18..6ca1292efb6b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -147,4 +147,4 @@ private void verifyShuffleFileOpenForWriting() { fileType)); } } -} \ No newline at end of file +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java index 20a72851304b..b145c0d5e8bd 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java @@ -128,12 +128,14 @@ private RegisterExecutorCallback(String appId, String execId) { @Override public void onSuccess(ByteBuffer response) { - logger.info("Successfully registered " + appId + ":" + execId + " with external shuffle service."); + logger.info("Successfully registered " + + appId + ":" + execId + " with external shuffle service."); } @Override public void onFailure(Throwable e) { - logger.warn("Unable to register " + appId + ":" + execId + " with external shuffle service, " + e); + logger.warn("Unable to register " + + appId + ":" + execId + " with external shuffle service, " + e); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java index 64a2fd77333b..39bfb95b4af3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java @@ -87,4 +87,4 @@ public static RegisterExecutorWithExternal decode(ByteBuf buf) { String shuffleManager = Encoders.Strings.decode(buf); return new RegisterExecutorWithExternal(appId, execId, shuffleManager); } -} \ No newline at end of file +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index 4135458ddd43..ae9ada03e760 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -30,10 +30,11 @@ public interface ShufflePartitionWriter { OutputStream openPartitionStream(); /** - * Indicate that the partition was written successfully and there are no more incoming bytes. Returns - * the length of the partition that is written. Note that returning the length is mainly for backwards - * compatibility and should be removed in a more polished variant. After this method is called, the writer - * will be discarded; it's expected that the implementation will close any underlying resources. + * Indicate that the partition was written successfully and there are no more incoming bytes. + * Returns the length of the partition that is written. Note that returning the length is + * mainly for backwards compatibility and should be removed in a more polished variant. + * After this method is called, the writer will be discarded; it's expected that the + * implementation will close any underlying resources. */ long commitAndGetTotalLength(); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 2b9dbb6b226a..cf79e845a9cc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -23,7 +23,6 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private final String hostname; private final int port; private final String execId; - private final String driverHostName; public ExternalShuffleDataIO( SparkConf sparkConf) { @@ -32,15 +31,15 @@ public ExternalShuffleDataIO( this.securityManager = sparkEnv.securityManager(); this.hostname = sparkEnv.blockManager().blockTransferService().hostName(); - this.driverHostName = sparkEnv.blockManager().master().driverEndpoint().address().hostPort(); - int tmpPort = Integer.parseInt( - Utils.getSparkOrYarnConfig(sparkConf, SHUFFLE_SERVICE_PORT_CONFIG, DEFAULT_SHUFFLE_PORT)); + int tmpPort = Integer.parseInt(Utils.getSparkOrYarnConfig( + sparkConf, SHUFFLE_SERVICE_PORT_CONFIG, DEFAULT_SHUFFLE_PORT)); if (tmpPort == 0) { this.port = Integer.parseInt(sparkConf.get(SHUFFLE_SERVICE_PORT_CONFIG)); } else { this.port = tmpPort; } + this.execId = SparkEnv.get().blockManager().shuffleServerId().executorId(); } @@ -52,13 +51,14 @@ public void initialize() { @Override public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( - conf, securityManager.isAuthenticationEnabled(), securityManager, hostname, port, execId); + conf, securityManager.isAuthenticationEnabled(), + securityManager, hostname, port, execId); } @Override public ShuffleWriteSupport writeSupport() { return new ExternalShuffleWriteSupport( conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port, execId, driverHostName); + securityManager, hostname, port, execId); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index 75630566b1da..9d6d9cd1f432 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -1,6 +1,5 @@ package org.apache.spark.shuffle.external; -import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.server.OneForOneStreamManager; @@ -20,7 +19,8 @@ public class ExternalShufflePartitionReader implements ShufflePartitionReader { - private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionReader.class); + private static final Logger logger = + LoggerFactory.getLogger(ExternalShufflePartitionReader.class); private final TransportClient client; private final String appId; @@ -28,7 +28,12 @@ public class ExternalShufflePartitionReader implements ShufflePartitionReader { private final int shuffleId; private final int mapId; - public ExternalShufflePartitionReader(TransportClient client, String appId, String execId, int shuffleId, int mapId) { + public ExternalShufflePartitionReader( + TransportClient client, + String appId, + String execId, + int shuffleId, + int mapId) { this.client = client; this.appId = appId; this.execId = execId; @@ -38,16 +43,19 @@ public ExternalShufflePartitionReader(TransportClient client, String appId, Stri @Override public InputStream fetchPartition(int reduceId) { - OpenShufflePartition openMessage = new OpenShufflePartition(appId, execId, shuffleId, mapId, reduceId); + OpenShufflePartition openMessage = + new OpenShufflePartition(appId, execId, shuffleId, mapId, reduceId); - ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000 /* what should be the default? */); + ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); try { StreamCombiningCallback callback = new StreamCombiningCallback(); - StreamHandle streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); + StreamHandle streamHandle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); for (int i = 0; i < streamHandle.numChunks; i++) { - client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), - callback); + client.stream( + OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), + callback); } return callback.getCombinedInputStream(); } catch (Exception e) { @@ -59,9 +67,9 @@ public InputStream fetchPartition(int reduceId) { private class StreamCombiningCallback implements StreamCallback { public boolean failed; - public final Vector inputStreams; + private final Vector inputStreams; - public StreamCombiningCallback() { + private StreamCombiningCallback() { inputStreams = new Vector<>(); failed = false; } @@ -84,7 +92,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException { } } - public SequenceInputStream getCombinedInputStream() { + private SequenceInputStream getCombinedInputStream() { if (failed) { throw new RuntimeException("Stream chunk gathering failed"); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 7c010dacc38c..61cda892dcdb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -1,6 +1,5 @@ package org.apache.spark.shuffle.external; -import org.apache.hadoop.hive.serde2.ByteStream; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; @@ -15,7 +14,8 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { - private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); + private static final Logger logger = + LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); private final TransportClient client; private final String appId; @@ -23,7 +23,6 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private final int shuffleId; private final int mapId; private final int partitionId; - private final String driverHostPort; private long totalLength = 0; private final ByteArrayOutputStream partitionBuffer = new ByteArrayOutputStream(); @@ -34,15 +33,13 @@ public ExternalShufflePartitionWriter( String execId, int shuffleId, int mapId, - int partitionId, - String driverHostPort) { + int partitionId) { this.client = client; this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; - this.driverHostPort = driverHostPort; } @Override @@ -65,8 +62,9 @@ public void onFailure(Throwable e) { }; try { ByteBuffer streamHeader = - new UploadShufflePartitionStream( - this.appId, execId, shuffleId, mapId, partitionId, driverHostPort).toByteBuffer(); + new UploadShufflePartitionStream( + this.appId, execId, shuffleId, mapId, + partitionId).toByteBuffer(); int size = partitionBuffer.size(); byte[] buf = partitionBuffer.toByteArray(); @@ -86,6 +84,7 @@ public void onFailure(Throwable e) { @Override public void abort(Exception failureReason) { - logger.error("Encountered error while attempting to upload partition to ESS", failureReason); + logger.error("Encountered error while attempting" + + "to upload partition to ESS", failureReason); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 388628e9d24d..60ae1c9fbe6d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -27,18 +27,16 @@ public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { private final String hostname; private final int port; private final String execId; - private final String driverHostPort; public ExternalShuffleWriteSupport( TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostname, int port, String execId, String driverHostPort) { + String hostname, int port, String execId) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; this.hostname = hostname; this.port = port; this.execId = execId; - this.driverHostPort = driverHostPort; } @Override @@ -55,7 +53,7 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) { try { TransportClient client = clientFactory.createClient(hostname, port); return new ExternalShufflePartitionWriter( - client, appId, execId, shuffleId, mapId, partitionId, driverHostPort); + client, appId, execId, shuffleId, mapId, partitionId); } catch (Exception e) { logger.error("Encountered error while creating transport client"); throw new RuntimeException(e); // what is standard practice here? @@ -69,7 +67,8 @@ public void commitAllPartitions() { @Override public void abort(Exception exception) { - logger.error("Encountered error while attempting to all partitions to ESS", exception); + logger.error("Encountered error while" + + "attempting to all partitions to ESS", exception); } }; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 7336a75af123..32be62009511 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -542,7 +542,8 @@ private long[] mergeSpillsWithPluggableWriter( partitionInputStream = blockManager.serializerManager().wrapForEncryption( partitionInputStream); if (compressionCodec != null) { - partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); + partitionInputStream = + compressionCodec.compressedInputStream(partitionInputStream); } Utils.copyStream(partitionInputStream, partitionOutput, false, false); } finally { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 1c4fa4bc6541..fb587f02256e 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,6 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration @@ -33,7 +34,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle.MetadataFetchFailedException +import org.apache.spark.shuffle._ import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -213,6 +214,7 @@ private[spark] sealed trait MapOutputTrackerMessage private[spark] case class GetMapOutputStatuses(shuffleId: Int) extends MapOutputTrackerMessage private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage +private[spark] case object GetRemoteShuffleServiceAddresses extends MapOutputTrackerMessage private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext) @@ -233,6 +235,9 @@ private[spark] class MapOutputTrackerMasterEndpoint( logInfo("MapOutputTrackerMasterEndpoint stopped!") context.reply(true) stop() + + case GetRemoteShuffleServiceAddresses => + context.reply(tracker.getRemoteShuffleServiceAddresses) } } @@ -318,7 +323,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging private[spark] class MapOutputTrackerMaster( conf: SparkConf, broadcastManager: BroadcastManager, - isLocal: Boolean) + isLocal: Boolean, + shuffleServiceAddressProvider: ShuffleServiceAddressProvider + = DefaultShuffleServiceAddressProvider) extends MapOutputTracker(conf) { // The size at which we use Broadcast to send the map output statuses to the executors @@ -644,6 +651,9 @@ private[spark] class MapOutputTrackerMaster( } } + def getRemoteShuffleServiceAddresses: List[(String, Int)] = + shuffleServiceAddressProvider.getShuffleServiceAddresses() + // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 66038eeaea54..68ef2ed98107 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,13 +19,13 @@ package org.apache.spark import java.io.File import java.net.Socket -import java.util.Locale +import java.util.{Locale, ServiceLoader} +import com.google.common.collect.MapMaker +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Properties -import com.google.common.collect.MapMaker - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager @@ -39,7 +39,7 @@ import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory} import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -302,7 +302,21 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) + val loader = Utils.getContextOrSparkClassLoader + val master = conf.get("spark.master") + val serviceLoaders = + ServiceLoader.load(classOf[ShuffleServiceAddressProviderFactory], loader) + .asScala.filter(_.canCreate(conf.get("spark.master"))) + if (serviceLoaders.size > 1) { + throw new SparkException( + s"Multiple external cluster managers registered for the url $master: $serviceLoaders") + } + val shuffleServiceAddressProvider = serviceLoaders.headOption + .map(_.create(conf)) + .getOrElse(DefaultShuffleServiceAddressProvider) + shuffleServiceAddressProvider.start() + + new MapOutputTrackerMaster(conf, broadcastManager, isLocal, shuffleServiceAddressProvider) } else { new MapOutputTrackerWorker(conf) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala similarity index 60% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala rename to core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala index 83daddf71448..96d529872b30 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodStates.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProvider.scala @@ -14,24 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.scheduler.cluster.k8s -import io.fabric8.kubernetes.api.model.Pod +package org.apache.spark.shuffle -sealed trait ExecutorPodState { - def pod: Pod +trait ShuffleServiceAddressProvider { + def start(): Unit = {} + def getShuffleServiceAddresses(): List[(String, Int)] + def stop(): Unit = {} } -case class PodRunning(pod: Pod) extends ExecutorPodState - -case class PodPending(pod: Pod) extends ExecutorPodState - -sealed trait FinalPodState extends ExecutorPodState - -case class PodSucceeded(pod: Pod) extends FinalPodState - -case class PodFailed(pod: Pod) extends FinalPodState - -case class PodDeleted(pod: Pod) extends FinalPodState - -case class PodUnknown(pod: Pod) extends ExecutorPodState +private[spark] object DefaultShuffleServiceAddressProvider extends ShuffleServiceAddressProvider { + override def getShuffleServiceAddresses(): List[(String, Int)] = List.empty[(String, Int)] +} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala new file mode 100644 index 000000000000..68adb8e44585 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf + +trait ShuffleServiceAddressProviderFactory { + def canCreate(masterUrl: String): Boolean + def create(conf: SparkConf): ShuffleServiceAddressProvider +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 9db290f9ba2c..2a935bc2fee3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -184,6 +184,8 @@ private[spark] class BlockManager( } } + private var remoteShuffleServiceAddress: Option[(String, Int)] = None + var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external @@ -264,8 +266,18 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id - // TODO: Customize so that the shuffleServiceID is pointing to K8s - shuffleServerId = if (externalShuffleServiceEnabled) { + if (!blockManagerId.isDriver && externalk8sShuffleServiceEnabled) { + remoteShuffleServiceAddress = Random.shuffle(mapOutputTracker + .trackerEndpoint + .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses)) + .headOption + } + + shuffleServerId = if (externalk8sShuffleServiceEnabled && !blockManagerId.isDriver) { + val (hostName, port) = remoteShuffleServiceAddress.getOrElse( + throw new SparkException("No K8S External Shuffle Addresses")) + BlockManagerId(executorId, hostName, port) + } else if (externalNonK8sShuffleService) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 18d8e09589c8..539336cd4fd8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,7 +23,6 @@ import java.nio.file.StandardOpenOption; import java.util.*; -import org.apache.commons.io.IOUtils; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -351,7 +350,8 @@ private void testMergingSpills( private void testMergingSpills( boolean transferToEnabled, boolean useShuffleWriterPlugin) throws IOException { - final UnsafeShuffleWriter writer = createWriter(transferToEnabled, useShuffleWriterPlugin); + final UnsafeShuffleWriter writer = + createWriter(transferToEnabled, useShuffleWriterPlugin); final ArrayList> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { dataToWrite.add(new Tuple2<>(i, i)); @@ -649,7 +649,8 @@ public ShuffleMapOutputWriter newMapOutputWriter( try { if (!mergedOutputFile.exists() && !mergedOutputFile.createNewFile()) { throw new IllegalStateException( - String.format("Failed to create merged output file %s.", mergedOutputFile.getAbsolutePath())); + String.format("Failed to create merged output file %s.", + mergedOutputFile.getAbsolutePath())); } } catch (IOException e) { throw new RuntimeException(e); diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 6efe73cb659d..3ba8f7017acf 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -168,7 +168,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { logInfo(s"Application $appId timed out. Removing shuffle files.") connectedApps.remove(appId) - applicationRemoved(appId, true) + // TODO: Write removal logic } } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala index 77bd66b608e7..44e843cdb73a 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkKubernetesClientFactory.scala @@ -21,11 +21,13 @@ import java.io.File import com.google.common.base.Charsets import com.google.common.io.Files import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient} +import io.fabric8.kubernetes.client.Config._ import io.fabric8.kubernetes.client.utils.HttpClientUtils import okhttp3.Dispatcher import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.util.ThreadUtils /** @@ -34,6 +36,35 @@ import org.apache.spark.util.ThreadUtils * options for different components. */ private[spark] object SparkKubernetesClientFactory { + def getDriverKubernetesClient(conf: SparkConf, masterURL: String): KubernetesClient = { + val wasSparkSubmittedInClusterMode = conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) + val (authConfPrefix, + apiServerUri, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { + require(conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, + "If the application is deployed using spark-submit in cluster mode, the driver pod name " + + "must be provided.") + (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, + KUBERNETES_MASTER_INTERNAL_URL, + Some(new File(KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), + Some(new File(KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) + } else { + (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, + KubernetesUtils.parseMasterUrl(masterURL), + None, + None) + } + + val kubernetesClient = createKubernetesClient( + apiServerUri, + Some(conf.get(KUBERNETES_NAMESPACE)), + authConfPrefix, + conf, + defaultServiceAccountToken, + defaultServiceAccountCaCrt) + kubernetesClient + } def createKubernetesClient( master: String, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala index 435a5f1461c9..aa4ce28aeb6b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsSnapshot.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging /** * An immutable view of the current executor pods that are running in the cluster. */ -private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, ExecutorPodState]) { +private[spark] case class ExecutorPodsSnapshot(executorPods: Map[Long, SparkPodState]) { import ExecutorPodsSnapshot._ @@ -42,15 +42,15 @@ object ExecutorPodsSnapshot extends Logging { ExecutorPodsSnapshot(toStatesByExecutorId(executorPods)) } - def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, ExecutorPodState]) + def apply(): ExecutorPodsSnapshot = ExecutorPodsSnapshot(Map.empty[Long, SparkPodState]) - private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, ExecutorPodState] = { + private def toStatesByExecutorId(executorPods: Seq[Pod]): Map[Long, SparkPodState] = { executorPods.map { pod => (pod.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL).toLong, toState(pod)) }.toMap } - private def toState(pod: Pod): ExecutorPodState = { + private def toState(pod: Pod): SparkPodState = { if (isDeleted(pod)) { PodDeleted(pod) } else { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index b31fbb420ed6..fd245f8f5f31 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -42,32 +42,8 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit sc: SparkContext, masterURL: String, scheduler: TaskScheduler): SchedulerBackend = { - val wasSparkSubmittedInClusterMode = sc.conf.get(KUBERNETES_DRIVER_SUBMIT_CHECK) - val (authConfPrefix, - apiServerUri, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) = if (wasSparkSubmittedInClusterMode) { - require(sc.conf.get(KUBERNETES_DRIVER_POD_NAME).isDefined, - "If the application is deployed using spark-submit in cluster mode, the driver pod name " + - "must be provided.") - (KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX, - KUBERNETES_MASTER_INTERNAL_URL, - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), - Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - } else { - (KUBERNETES_AUTH_CLIENT_MODE_PREFIX, - KubernetesUtils.parseMasterUrl(masterURL), - None, - None) - } - - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( - apiServerUri, - Some(sc.conf.get(KUBERNETES_NAMESPACE)), - authConfPrefix, - sc.conf, - defaultServiceAccountToken, - defaultServiceAccountCaCrt) + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + sc.conf, masterURL) if (sc.conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) { KubernetesUtils.loadPodFromTemplate( diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala new file mode 100644 index 000000000000..e5d9594fc3d5 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/SparkPodStates.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.scheduler.cluster.k8s + +import java.util.Locale + +import io.fabric8.kubernetes.api.model.Pod + +import org.apache.spark.internal.Logging + +sealed trait SparkPodState { + def pod: Pod +} + +case class PodRunning(pod: Pod) extends SparkPodState + +case class PodPending(pod: Pod) extends SparkPodState + +sealed trait FinalPodState extends SparkPodState + +case class PodSucceeded(pod: Pod) extends FinalPodState + +case class PodFailed(pod: Pod) extends FinalPodState + +case class PodDeleted(pod: Pod) extends FinalPodState + +case class PodUnknown(pod: Pod) extends SparkPodState + +object SparkPodState extends Logging { + def toState(pod: Pod): SparkPodState = { + if (isDeleted(pod)) { + PodDeleted(pod) + } else { + val phase = pod.getStatus.getPhase.toLowerCase(Locale.ROOT) + phase match { + case "pending" => + PodPending(pod) + case "running" => + PodRunning(pod) + case "failed" => + PodFailed(pod) + case "succeeded" => + PodSucceeded(pod) + case _ => + logWarning(s"Received unknown phase $phase for executor pod with name" + + s" ${pod.getMetadata.getName} in namespace ${pod.getMetadata.getNamespace}") + PodUnknown(pod) + } + } + } + + private def isDeleted(pod: Pod): Boolean = pod.getMetadata.getDeletionTimestamp != null +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala new file mode 100644 index 000000000000..ec3e6d0483d9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.shuffle + +import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} +import java.util.concurrent.locks.ReentrantReadWriteLock + +import io.fabric8.kubernetes.api.model.Pod +import io.fabric8.kubernetes.client.{KubernetesClient, KubernetesClientException, Watch, Watcher} +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler.cluster.k8s._ +import org.apache.spark.scheduler.cluster.k8s.SparkPodState +import org.apache.spark.shuffle._ +import org.apache.spark.util.Utils + +class KubernetesShuffleServiceAddressProvider( + kubernetesClient: KubernetesClient, + pollForPodsExecutor: ScheduledExecutorService, + podLabels: Map[String, String], + namespace: String, + portNumber: Int) + extends ShuffleServiceAddressProvider with Logging { + + // General implementation remark: this bears a strong resemblance to ExecutorPodsSnapshotsStore, + // but we don't need all "in-between" lists of all executor pods, just the latest known list + // when we query in getShuffleServiceAddresses. + + private val podsUpdateLock = new ReentrantReadWriteLock() + + private val shuffleServicePods = mutable.HashMap.empty[String, Pod] + + private var shuffleServicePodsWatch: Watch = _ + private var pollForPodsTask: ScheduledFuture[_] = _ + + override def start(): Unit = { + pollForPods() + pollForPodsTask = pollForPodsExecutor.scheduleWithFixedDelay( + () => pollForPods(), 0, 10, TimeUnit.SECONDS) + shuffleServicePodsWatch = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava).watch(new PutPodsInCacheWatcher()) + } + + override def stop(): Unit = { + Utils.tryLogNonFatalError { + if (pollForPodsTask != null) { + pollForPodsTask.cancel(false) + } + } + + Utils.tryLogNonFatalError { + if (shuffleServicePodsWatch != null) { + shuffleServicePodsWatch.close() + } + } + + Utils.tryLogNonFatalError { + kubernetesClient.close() + } + } + + override def getShuffleServiceAddresses(): List[(String, Int)] = { + val readLock = podsUpdateLock.readLock() + readLock.lock() + try { + val addresses = shuffleServicePods.values.map(pod => { + (pod.getStatus.getPodIP, portNumber) + }).toList + logInfo(s"Found remote shuffle service addresses at $addresses.") + addresses + } finally { + readLock.unlock() + } + } + + private def pollForPods(): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + val allPods = kubernetesClient + .pods() + .inNamespace(namespace) + .withLabels(podLabels.asJava) + .list() + shuffleServicePods.clear() + allPods.getItems.asScala.foreach(updatePod) + } finally { + writeLock.unlock() + } + } + + private def updatePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only update pods under lock.") + val state = SparkPodState.toState(pod) + state match { + case PodPending(_) | PodFailed(_) | PodSucceeded(_) | PodDeleted(_) => + shuffleServicePods.remove(pod.getMetadata.getName) + case PodRunning(_) => + shuffleServicePods.put(pod.getMetadata.getName, pod) + case _ => + logWarning(s"Unknown state $state for pod named ${pod.getMetadata.getName}") + } + } + + private def deletePod(pod: Pod): Unit = { + require(podsUpdateLock.isWriteLockedByCurrentThread, "Should only delete under lock.") + shuffleServicePods.remove(pod.getMetadata.getName) + } + + private class PutPodsInCacheWatcher extends Watcher[Pod] { + override def eventReceived(action: Watcher.Action, pod: Pod): Unit = { + val writeLock = podsUpdateLock.writeLock() + writeLock.lock() + try { + updatePod(pod) + } finally { + writeLock.unlock() + } + } + + override def onClose(e: KubernetesClientException): Unit = {} + } + + private implicit def toRunnable(func: () => Unit): Runnable = { + new Runnable { + override def run(): Unit = func() + } + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala new file mode 100644 index 000000000000..39a31f645ccf --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.k8s.shuffle + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory +import org.apache.spark.internal.{config => C} +import org.apache.spark.shuffle._ +import org.apache.spark.util.ThreadUtils + +class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { + override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") + + override def create(conf: SparkConf): ShuffleServiceAddressProvider = { + if (conf.get(C.K8S_SHUFFLE_SERVICE_ENABLED)) { + val kubernetesClient = SparkKubernetesClientFactory.getDriverKubernetesClient( + conf, conf.get("spark.master")) + val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( + "poll-shuffle-service-pods", 1) + val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS) + val shuffleServicePodsNamespace = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE) + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + + s" shuffle service must be defined by" + + s" ${KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") + require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") + + val port: Int = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT) + new KubernetesShuffleServiceAddressProvider( + kubernetesClient, + pollForPodsExecutor, + shuffleServiceLabels.toMap, + shuffleServicePodsNamespace.get, + port) + } else DefaultShuffleServiceAddressProvider + } +} From 646f1bf401023de75e58da5005e39dcf57d41538 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 28 Dec 2018 11:03:05 -0500 Subject: [PATCH 04/30] running tests on driver and executor logic --- .../external/ExternalShuffleDataIO.java | 2 +- .../scala/org/apache/spark/SparkEnv.scala | 14 ++++--- .../spark/internal/config/package.scala | 6 +++ .../apache/spark/storage/BlockManager.scala | 40 ++++++++++--------- .../spark/examples/SkewedGroupByTest.scala | 2 +- .../org/apache/spark/deploy/k8s/Config.scala | 6 +-- .../KubernetesExternalShuffleService.scala | 40 +++++++++++++++---- ...ernetesShuffleServiceAddressProvider.scala | 8 ++-- ...ShuffleServiceAddressProviderFactory.scala | 12 +++--- .../src/main/dockerfiles/spark/Dockerfile | 2 +- .../cluster/mesos/MesosClusterManager.scala | 4 ++ .../cluster/YarnClusterManager.scala | 3 ++ sbin/start-k8s-shuffle-service.sh | 34 ++++++++++++++++ 13 files changed, 127 insertions(+), 46 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/{ => k8s}/KubernetesShuffleServiceAddressProvider.scala (96%) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/{ => k8s}/KubernetesShuffleServiceAddressProviderFactory.scala (85%) create mode 100644 sbin/start-k8s-shuffle-service.sh diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index cf79e845a9cc..d6d6167259cf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -45,7 +45,7 @@ public ExternalShuffleDataIO( @Override public void initialize() { - // TODO: hmmmm? maybe register? idk + // TODO: move registerDriver and registerExecutor here } @Override diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 68ef2ed98107..45aabc05f49b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import java.util.{Locale, ServiceLoader} import com.google.common.collect.MapMaker -import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Properties @@ -302,15 +301,20 @@ object SparkEnv extends Logging { val broadcastManager = new BroadcastManager(isDriver, conf, securityManager) val mapOutputTracker = if (isDriver) { - val loader = Utils.getContextOrSparkClassLoader val master = conf.get("spark.master") - val serviceLoaders = - ServiceLoader.load(classOf[ShuffleServiceAddressProviderFactory], loader) - .asScala.filter(_.canCreate(conf.get("spark.master"))) + val shuffleProvider = conf.get(SHUFFLE_SERVICE_PROVIDER_CLASS) + .map(clazz => Utils.loadExtensions( + classOf[ShuffleServiceAddressProviderFactory], + Seq(clazz), conf)).getOrElse(Seq()) + val serviceLoaders = shuffleProvider + .filter(_.canCreate(conf.get("spark.master"))) if (serviceLoaders.size > 1) { throw new SparkException( s"Multiple external cluster managers registered for the url $master: $serviceLoaders") } + val loader = Utils.getContextOrSparkClassLoader + logInfo(s"Loader: $loader") + logInfo(s"Service loader: $serviceLoaders") val shuffleServiceAddressProvider = serviceLoaders.headOption .map(_.create(conf)) .getOrElse(DefaultShuffleServiceAddressProvider) diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 533672f809f4..c249565fa151 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -444,6 +444,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_SERVICE_PROVIDER_CLASS = + ConfigBuilder("spark.shuffle.provider.plugin.class") + .doc("Experimental. Specify a class that can handle detecting shuffle service pods.") + .stringConf + .createOptional + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = ConfigBuilder("spark.shuffle.io.plugin.class") .doc("Experimental. Specify a class that can handle reading and writing shuffle blocks to" + diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2a935bc2fee3..2ea7fdf27cca 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -184,7 +184,7 @@ private[spark] class BlockManager( } } - private var remoteShuffleServiceAddress: Option[(String, Int)] = None + private var remoteShuffleServiceAddress: List[(String, Int)] = List() var blockManagerId: BlockManagerId = _ @@ -266,16 +266,14 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id - if (!blockManagerId.isDriver && externalk8sShuffleServiceEnabled) { - remoteShuffleServiceAddress = Random.shuffle(mapOutputTracker + if (externalk8sShuffleServiceEnabled) { + remoteShuffleServiceAddress = mapOutputTracker .trackerEndpoint - .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses)) - .headOption + .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses) } - shuffleServerId = if (externalk8sShuffleServiceEnabled && !blockManagerId.isDriver) { - val (hostName, port) = remoteShuffleServiceAddress.getOrElse( - throw new SparkException("No K8S External Shuffle Addresses")) + shuffleServerId = if (externalk8sShuffleServiceEnabled) { + val (hostName, port) = Random.shuffle(remoteShuffleServiceAddress).head BlockManagerId(executorId, hostName, port) } else if (externalNonK8sShuffleService) { logInfo(s"external shuffle service port = $externalShuffleServicePort") @@ -285,18 +283,22 @@ private[spark] class BlockManager( } if (externalk8sShuffleServiceEnabled && blockManagerId.isDriver) { - // Register Drivers' configuration with the k8s shuffle service - shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] - .registerDriverWithShuffleService( - shuffleServerId.host, shuffleServerId.port, - conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), - conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) + // Register Drivers' configuration with the k8s shuffle services + remoteShuffleServiceAddress.foreach { ssId => + shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] + .registerDriverWithShuffleService( + ssId._1, ssId._2, + conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + s"${conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), + conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) + } } else if (externalk8sShuffleServiceEnabled && !blockManagerId.isDriver) { - shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] - .registerExecutorWithShuffleService( - shuffleServerId.host, shuffleServerId.port, appId, - shuffleServerId.executorId, shuffleManager.getClass.getName) + remoteShuffleServiceAddress.foreach { ssId => + shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] + .registerExecutorWithShuffleService( + ssId._1, ssId._2, appId, + shuffleServerId.executorId, shuffleManager.getClass.getName) + } } else if (externalNonK8sShuffleService && !blockManagerId.isDriver) { // Register Executors' configuration with the local shuffle service, if one should exist. registerWithExternalShuffleServer() diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index 4d3c34041bc1..a1e6f83f7cba 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -23,7 +23,7 @@ import java.util.Random import org.apache.spark.sql.SparkSession /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + * Usage: SkewedGroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] */ object SkewedGroupByTest { def main(args: Array[String]) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index b2c67ba7f920..44d5a2c22dd4 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -291,14 +291,14 @@ private[spark] object Config extends Logging { val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE = ConfigBuilder("spark.kubernetes.shuffle.service.remote.pods.namespace") - .doc("Namespace of the pods that are running the shuffle service instances for backing up" + - " shuffle data.") + .doc("Namespace of the pods that are running the shuffle service instances for remote" + + " pushing of shuffle data.") .stringConf .createOptional val KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT = ConfigBuilder("spark.kubernetes.shuffle.service.remote.port") - .doc("Port of the shuffle services that will back up the application's shuffle data.") + .doc("Port of the external k8s shuffle service pods") .intConf .createWithDefault(7337) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 3ba8f7017acf..c83683f417af 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -21,6 +21,7 @@ import java.nio.ByteBuffer import java.nio.file.Files import java.nio.file.Paths import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.function.BiFunction import scala.collection.JavaConverters._ @@ -49,7 +50,9 @@ private[spark] class KubernetesExternalShuffleBlockHandler( // Stores a map of app id to app state (timeout value and last heartbeat) private val connectedApps = new ConcurrentHashMap[String, AppState]() - private val registeredExecutors = new ConcurrentHashMap[AppExecId, ExecutorShuffleInfo]() + private val registeredExecutors = + new ConcurrentHashMap[String, Map[String, ExecutorShuffleInfo]]() + private val knownManagers = Array( "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") @@ -74,8 +77,19 @@ private[spark] class KubernetesExternalShuffleBlockHandler( } val executorShuffleInfo = new ExecutorShuffleInfo( Array(executorDir.getAbsolutePath), 1, shuffleManager) + val execMap = Map(execId -> executorShuffleInfo) + registeredExecutors.merge(appId, execMap, + new BiFunction[ + Map[String, ExecutorShuffleInfo], + Map[String, ExecutorShuffleInfo], + Map[String, ExecutorShuffleInfo]]() { + override def apply( + t: Map[String, ExecutorShuffleInfo], u: Map[String, ExecutorShuffleInfo]): + Map[String, ExecutorShuffleInfo] = { + t ++ u + } + }) logInfo(s"Registering executor ${fullId} with ${executorShuffleInfo}") - registeredExecutors.put(fullId, executorShuffleInfo) case RegisterDriverParam(appId, appState) => val address = client.getSocketAddress @@ -86,7 +100,12 @@ private[spark] class KubernetesExternalShuffleBlockHandler( logWarning(s"Received a registration request from app $appId, but it was already " + s"registered") } + val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile + if (!driverDir.mkdir()) { + throw new RuntimeException(s"Failed to create dir ${driverDir.getAbsolutePath}") + } connectedApps.put(appId, appState) + registeredExecutors.put(appId, Map[String, ExecutorShuffleInfo]()) callback.onSuccess(ByteBuffer.allocate(0)) case Heartbeat(appId) => @@ -111,6 +130,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( header match { case UploadParam( appId, execId, shuffleId, mapId, partitionId) => + logInfo(s"Received upload param from app $appId from $execId") getFileWriterStreamCallback( appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) case _ => super.handleStream(header, client, callback) @@ -124,17 +144,21 @@ private[spark] class KubernetesExternalShuffleBlockHandler( mapId: Int, extension: String, fileType: FileWriterStreamCallback.FileType): StreamCallbackWithID = { - val fullId = new AppExecId(appId, execId) - val executor = registeredExecutors.get(fullId) + val execMap = registeredExecutors.get(appId) + if (execMap == null) { + throw new RuntimeException( + s"appId=$appId is not registered for remote shuffle") + } + val executor = execMap(execId) if (executor == null) { throw new RuntimeException( - s"Executor is not registered for remote shuffle (appId=$appId, execId=$execId)") + s"App is not registered for remote shuffle (appId=$appId, execId=$execId)") } - val backedUpFile = + val file = ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0." + extension) val streamCallback = - new FileWriterStreamCallback(fullId, shuffleId, mapId, backedUpFile, fileType) + new FileWriterStreamCallback(new AppExecId(appId, execId), shuffleId, mapId, file, fileType) streamCallback.open() streamCallback } @@ -168,6 +192,8 @@ private[spark] class KubernetesExternalShuffleBlockHandler( if (now - appState.lastHeartbeat > appState.heartbeatTimeout * 1000 * 1000) { logInfo(s"Application $appId timed out. Removing shuffle files.") connectedApps.remove(appId) + applicationRemoved(appId, false) + registeredExecutors.remove(appId) // TODO: Write removal logic } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala similarity index 96% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala index ec3e6d0483d9..63074f6f14d7 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProvider.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.deploy.k8s.shuffle +package org.apache.spark.shuffle.k8s import java.util.concurrent.{ScheduledExecutorService, ScheduledFuture, TimeUnit} import java.util.concurrent.locks.ReentrantReadWriteLock @@ -26,11 +26,11 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.cluster.k8s._ -import org.apache.spark.scheduler.cluster.k8s.SparkPodState -import org.apache.spark.shuffle._ +import org.apache.spark.scheduler.cluster.k8s.{SparkPodState, _} +import org.apache.spark.shuffle.ShuffleServiceAddressProvider import org.apache.spark.util.Utils + class KubernetesShuffleServiceAddressProvider( kubernetesClient: KubernetesClient, pollForPodsExecutor: ScheduledExecutorService, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala similarity index 85% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala index 39a31f645ccf..57e68d405329 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/KubernetesShuffleServiceAddressProviderFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProviderFactory.scala @@ -15,16 +15,17 @@ * limitations under the License. */ -package org.apache.spark.deploy.k8s.shuffle +package org.apache.spark.shuffle.k8s import org.apache.spark.SparkConf import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.internal.{config => C} +import org.apache.spark.internal.{config => C, Logging} import org.apache.spark.shuffle._ import org.apache.spark.util.ThreadUtils -class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddressProviderFactory { +class KubernetesShuffleServiceAddressProviderFactory + extends ShuffleServiceAddressProviderFactory with Logging { override def canCreate(masterUrl: String): Boolean = masterUrl.startsWith("k8s://") override def create(conf: SparkConf): ShuffleServiceAddressProvider = { @@ -33,12 +34,13 @@ class KubernetesShuffleServiceAddressProviderFactory extends ShuffleServiceAddre conf, conf.get("spark.master")) val pollForPodsExecutor = ThreadUtils.newDaemonThreadPoolScheduledExecutor( "poll-shuffle-service-pods", 1) + logInfo("Beginning to search for K8S pods that act as an External Shuffle Service") val shuffleServiceLabels = conf.getAllWithPrefix(KUBERNETES_REMOTE_SHUFFLE_SERVICE_LABELS) val shuffleServicePodsNamespace = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE) - require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the backup" + + require(shuffleServicePodsNamespace.isDefined, "Namespace for the pods running the external" + s" shuffle service must be defined by" + s" ${KUBERNETES_REMOTE_SHUFFLE_SERVICE_PODS_NAMESPACE.key}") - require(shuffleServiceLabels.nonEmpty, "Requires labels for the backup shuffle service pods.") + require(shuffleServiceLabels.nonEmpty, "Requires labels for external shuffle service pods") val port: Int = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_PORT) new KubernetesShuffleServiceAddressProvider( diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 084304032470..c37e17f92eda 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -28,7 +28,7 @@ ARG spark_uid=185 RUN set -ex && \ apk upgrade --no-cache && \ - apk add --no-cache bash tini libc6-compat linux-pam krb5 krb5-libs && \ + apk add --no-cache bash tini libc6-compat linux-pam krb5 krb5-libs procps && \ mkdir -p /opt/spark && \ mkdir -p /opt/spark/examples && \ mkdir -p /opt/spark/work-dir && \ diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index da71f8f9e407..48ef8df37ecc 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster.mesos import org.apache.spark.SparkContext import org.apache.spark.internal.config._ import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Mesos scheduler and backend @@ -59,6 +60,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index 64cd1bd08800..f3c9e3e2741f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl} +import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleServiceAddressProvider} /** * Cluster Manager for creation of Yarn scheduler and backend @@ -53,4 +54,6 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } + override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + DefaultShuffleServiceAddressProvider } diff --git a/sbin/start-k8s-shuffle-service.sh b/sbin/start-k8s-shuffle-service.sh new file mode 100644 index 000000000000..84ee30320220 --- /dev/null +++ b/sbin/start-k8s-shuffle-service.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Starts the K8S external shuffle server on the machine this script is executed on. +# TODO: Describe K8s ESS +# +# Usage: start-k8s-shuffle-service.sh +# +# + +if [ -z "${SPARK_HOME}" ]; then + export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi + +. "${SPARK_HOME}/sbin/spark-config.sh" +. "${SPARK_HOME}/bin/load-spark-env.sh" + +exec "${SPARK_HOME}/sbin"/spark-daemon.sh start org.apache.spark.deploy.k8s.KubernetesExternalShuffleService 1 From 5555bc9673116f0aeb8f54f079cdbfe3d891f640 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Fri, 28 Dec 2018 18:44:19 -0500 Subject: [PATCH 05/30] testing executor writing --- .../protocol/OpenShufflePartition.java | 1 + .../external/ExternalShuffleDataIO.java | 16 ++---- .../shuffle/sort/SortShuffleManager.scala | 11 ++-- .../apache/spark/storage/BlockManager.scala | 8 ++- .../spark/examples/GroupByShuffleTest.scala | 48 +++++++++++++++++ .../KubernetesExternalShuffleService.scala | 52 ++++++++++++++----- 6 files changed, 104 insertions(+), 32 deletions(-) create mode 100644 examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java index 408be7fad26d..78f6f5dc5910 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java @@ -53,6 +53,7 @@ public String toString() { .add("execId", execId) .add("shuffleId", shuffleId) .add("mapId", mapId) + .add("partitionId", partitionId) .toString(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index d6d6167259cf..2500711b69fd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -8,7 +8,7 @@ import org.apache.spark.shuffle.api.ShuffleReadSupport; import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.apache.spark.SecurityManager; -import org.apache.spark.util.Utils; +import org.apache.spark.storage.BlockManager; public class ExternalShuffleDataIO implements ShuffleDataIO { @@ -16,6 +16,7 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private static final String DEFAULT_SHUFFLE_PORT = "7337"; private static final SparkEnv sparkEnv = SparkEnv.get(); + private static final BlockManager blockManager = sparkEnv.blockManager(); private final SparkConf sparkConf; private final TransportConf conf; @@ -30,17 +31,10 @@ public ExternalShuffleDataIO( this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 2); this.securityManager = sparkEnv.securityManager(); - this.hostname = sparkEnv.blockManager().blockTransferService().hostName(); + this.hostname = blockManager.getRandomShuffleHost(); + this.port = blockManager.getRandomShufflePort(); - int tmpPort = Integer.parseInt(Utils.getSparkOrYarnConfig( - sparkConf, SHUFFLE_SERVICE_PORT_CONFIG, DEFAULT_SHUFFLE_PORT)); - if (tmpPort == 0) { - this.port = Integer.parseInt(sparkConf.get(SHUFFLE_SERVICE_PORT_CONFIG)); - } else { - this.port = tmpPort; - } - - this.execId = SparkEnv.get().blockManager().shuffleServerId().executorId(); + this.execId = blockManager.shuffleServerId().executorId(); } @Override diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index 53a5c4f3afba..ba56da9089a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -77,11 +77,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager " Shuffle will continue to spill to disk when necessary.") } - private val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) - - shuffleIoPlugin.foreach(_.initialize()) - /** * A mapping from shuffle ids to the number of mappers producing output for those shuffles. */ @@ -124,6 +119,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) + shuffleIoPlugin.foreach(_.initialize()) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], conf.getAppId, @@ -143,6 +141,9 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get + val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) + shuffleIoPlugin.foreach(_.initialize()) handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 2ea7fdf27cca..233dd1d4bfde 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -185,6 +185,7 @@ private[spark] class BlockManager( } private var remoteShuffleServiceAddress: List[(String, Int)] = List() + private var randomShuffleServiceAddress: (String, Int) = null var blockManagerId: BlockManagerId = _ @@ -273,8 +274,8 @@ private[spark] class BlockManager( } shuffleServerId = if (externalk8sShuffleServiceEnabled) { - val (hostName, port) = Random.shuffle(remoteShuffleServiceAddress).head - BlockManagerId(executorId, hostName, port) + randomShuffleServiceAddress = Random.shuffle(remoteShuffleServiceAddress).head + BlockManagerId(executorId, randomShuffleServiceAddress._1, randomShuffleServiceAddress._2) } else if (externalNonK8sShuffleService) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) @@ -366,6 +367,9 @@ private[spark] class BlockManager( } } + private[spark] def getRandomShuffleHost: String = randomShuffleServiceAddress._1 + private[spark] def getRandomShufflePort: Int = randomShuffleServiceAddress._2 + /** * Re-register with the master and report all blocks to it. This will be called by the heart beat * thread if our heartbeat to the block manager indicates that we were not registered. diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala new file mode 100644 index 000000000000..9d056a9f6f7b --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// scalastyle:off println +package org.apache.spark.examples + +import java.util.Random + +import org.apache.spark.sql.SparkSession + +/** + * Usage: GroupByShuffleTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ +object GroupByShuffleTest { + def main(args: Array[String]) { + val spark = SparkSession + .builder + .appName("GroupByShuffle Test") + .getOrCreate() + + val words = Array("one", "two", "two", "three", "three", "three") + val wordPairsRDD = spark.sparkContext.parallelize(words).map(word => (word, 1)) + + val wordCountsWithGroup = wordPairsRDD + .groupByKey() + .map(t => (t._1, t._2.sum)) + .collect() + + println(wordCountsWithGroup.mkString(",")) + + spark.stop() + } +} +// scalastyle:on println diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index c83683f417af..9ddc32f28137 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -17,12 +17,14 @@ package org.apache.spark.deploy.k8s +import java.io.{File, FileInputStream} import java.nio.ByteBuffer import java.nio.file.Files import java.nio.file.Paths import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.function.BiFunction +import org.apache.commons.io.IOUtils import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} @@ -118,6 +120,13 @@ private[spark] class KubernetesExternalShuffleBlockHandler( case None => logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + s"address $address, appId '$appId').") + case OpenParam(appId, execId, shuffleId, mapId, partitionId) => + logInfo(s"Received open param from app $appId from $execId") + val file = getFile( + appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + val fileInputStream = new FileInputStream(file) + val bytes = IOUtils.toByteArray(fileInputStream) + callback.onSuccess(ByteBuffer.wrap(bytes)) } case _ => super.handleMessage(message, client, callback) } @@ -133,7 +142,8 @@ private[spark] class KubernetesExternalShuffleBlockHandler( logInfo(s"Received upload param from app $appId from $execId") getFileWriterStreamCallback( appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) - case _ => super.handleStream(header, client, callback) + case _ => + super.handleStream(header, client, callback) } } @@ -144,25 +154,34 @@ private[spark] class KubernetesExternalShuffleBlockHandler( mapId: Int, extension: String, fileType: FileWriterStreamCallback.FileType): StreamCallbackWithID = { - val execMap = registeredExecutors.get(appId) - if (execMap == null) { - throw new RuntimeException( - s"appId=$appId is not registered for remote shuffle") - } - val executor = execMap(execId) - if (executor == null) { - throw new RuntimeException( - s"App is not registered for remote shuffle (appId=$appId, execId=$execId)") - } - val file = - ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0." + extension) + val file = getFile(appId, execId, shuffleId, mapId, extension, fileType) val streamCallback = new FileWriterStreamCallback(new AppExecId(appId, execId), shuffleId, mapId, file, fileType) streamCallback.open() streamCallback } + private def getFile( + appId: String, + execId: String, + shuffleId: Int, + mapId: Int, + extension: String, + fileType: FileWriterStreamCallback.FileType): File = { + val execMap = registeredExecutors.get(appId) + if (execMap == null) { + throw new RuntimeException( + s"appId=$appId is not registered for remote shuffle") + } + val executor = execMap(execId) + if (executor == null) { + throw new RuntimeException( + s"App is not registered for remote shuffle (appId=$appId, execId=$execId)") + } + ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, + "shuffle_" + shuffleId + "_" + mapId + "_0." + extension) + } + /** An extractor object for matching BlockTransferMessages. */ private object RegisterDriverParam { def unapply(r: RegisterDriver): Option[(String, AppState)] = @@ -183,6 +202,11 @@ private[spark] class KubernetesExternalShuffleBlockHandler( Some((e.appId, e.execId, e.shuffleManager)) } + private object OpenParam { + def unapply(o: OpenShufflePartition): Option[(String, String, Int, Int, Int)] = + Some((o.appId, o.execId, o.shuffleId, o.mapId, o.partitionId)) + } + private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) private class CleanerThread extends Runnable { From 7fabde733e2522f88dca3702e786eba5fc9a20b0 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 2 Jan 2019 13:59:26 -0500 Subject: [PATCH 06/30] added index file write and data read --- .../protocol/BlockTransferMessage.java | 7 +- .../protocol/UploadShuffleIndexStream.java | 102 ++++++++++++++++++ .../shuffle/api/ShuffleMapOutputWriter.java | 2 +- .../external/ExternalShuffleWriteSupport.java | 40 ++++++- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/UnsafeShuffleWriter.java | 3 +- .../util/collection/ExternalSorter.scala | 2 +- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../apache/spark/SplitFilesShuffleIO.scala | 2 +- .../KubernetesExternalShuffleService.scala | 65 ++++++++--- 10 files changed, 198 insertions(+), 29 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 8185193061a7..c38971231de7 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -41,7 +41,7 @@ public abstract class BlockTransferMessage implements Encodable { public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), - OPEN_SHUFFLE_PARTITION(8), REGISTER_EXECUTOR_WITH_EXTERNAL(9); + UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9), REGISTER_EXECUTOR_WITH_EXTERNAL(19); private final byte id; @@ -68,8 +68,9 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 5: return ShuffleServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); case 7: return UploadShufflePartitionStream.decode(buf); - case 8: return OpenShufflePartition.decode(buf); - case 9: return RegisterExecutorWithExternal.decode(buf); + case 8: return UploadShuffleIndexStream.decode(buf); + case 9: return OpenShufflePartition.decode(buf); + case 10: return RegisterExecutorWithExternal.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java new file mode 100644 index 000000000000..025459afeab3 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Upload shuffle partition request to the External Shuffle Service. + * This request should also include the driverHostPort for the sake of + * setting up a driver heartbeat to monitor heartbeat + */ +public class UploadShuffleIndexStream extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + public final int mapId; + + public UploadShuffleIndexStream( + String appId, + String execId, + int shuffleId, + int mapId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadShufflePartitionStream) { + UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + return Objects.equal(appId, o.appId) + && execId == o.execId + && shuffleId == o.shuffleId + && mapId == o.mapId; + } + return false; + } + + @Override + protected Type type() { + return Type.UPLOAD_SHUFFLE_INDEX_STREAM; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, execId, shuffleId, mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("execId", execId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static UploadShuffleIndexStream decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + String execId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new UploadShuffleIndexStream(appId, execId, shuffleId, mapId); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 06415dba72d3..f0f7d5ade602 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -21,7 +21,7 @@ public interface ShuffleMapOutputWriter { ShufflePartitionWriter newPartitionWriter(int partitionId); - void commitAllPartitions(); + void commitAllPartitions(long[] partitionLengths); void abort(Exception exception); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 60ae1c9fbe6d..620931e9e1a7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -2,12 +2,15 @@ import com.google.common.collect.Lists; import org.apache.spark.network.TransportContext; +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; +import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream; import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; @@ -15,6 +18,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.nio.ByteBuffer; +import java.nio.LongBuffer; import java.util.List; public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { @@ -55,14 +60,43 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) { return new ExternalShufflePartitionWriter( client, appId, execId, shuffleId, mapId, partitionId); } catch (Exception e) { - logger.error("Encountered error while creating transport client"); + logger.error("Encountered error while creating transport client", e); throw new RuntimeException(e); // what is standard practice here? } } @Override - public void commitAllPartitions() { - logger.info("Commiting all partitions"); + public void commitAllPartitions(long[] partitionLengths) { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully uploaded index"); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Encountered an error uploading index", e); + } + }; + try { + TransportClient client = clientFactory.createClient(hostname, port); + logger.info("Committing all partitions with a creation of an index file"); + ByteBuffer streamHeader = new UploadShuffleIndexStream( + appId, execId, shuffleId, mapId).toByteBuffer(); + // Size includes first 0L offset + ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8)); + LongBuffer longBuffer = byteBuffer.asLongBuffer(); + Long offset = 0L; + longBuffer.put(offset); + for (Long length: partitionLengths) { + offset += length; + longBuffer.put(offset); + } + client.uploadStream(new NioManagedBuffer(streamHeader), + new NioManagedBuffer(byteBuffer), callback); + } catch (Exception e) { + logger.error("Encountered error while creating transport client", e); + } } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 823c36d051dd..2cdf0c4600ae 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -267,7 +267,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio } } } - mapOutputWriter.commitAllPartitions(); + mapOutputWriter.commitAllPartitions(lengths); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 32be62009511..5a882c08dd05 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -563,8 +563,7 @@ private long[] mergeSpillsWithPluggableWriter( throw e; } } - mapOutputWriter.commitAllPartitions(); - threwException = false; + mapOutputWriter.commitAllPartitions(partitionLengths); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 569c8bd092f3..01cc838474d8 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -778,7 +778,7 @@ private[spark] class ExternalSorter[K, V, C]( } } } - mapOutputWriter.commitAllPartitions() + mapOutputWriter.commitAllPartitions(lengths) } catch { case e: Exception => util.Utils.tryLogNonFatalError { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 539336cd4fd8..0b18aceef92d 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -696,7 +696,7 @@ public void abort(Exception failureReason) { } @Override - public void commitAllPartitions() { + public void commitAllPartitions(long[] partitionlegnths) { } diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index 3a68fded945b..f6ac1fcc05a1 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -56,7 +56,7 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { } } - override def commitAllPartitions(): Unit = {} + override def commitAllPartitions(partitionLengths: Array[Long]): Unit = {} override def abort(exception: Exception): Unit = {} } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 9ddc32f28137..684e550c1ce2 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -17,20 +17,20 @@ package org.apache.spark.deploy.k8s -import java.io.{File, FileInputStream} +import java.io.File import java.nio.ByteBuffer -import java.nio.file.Files -import java.nio.file.Paths -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.nio.file.{Files, Paths} +import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} import java.util.function.BiFunction -import org.apache.commons.io.IOUtils +import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL import org.apache.spark.internal.Logging +import org.apache.spark.network.buffer.FileSegmentManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver._ @@ -44,22 +44,34 @@ import org.apache.spark.util.ThreadUtils */ private[spark] class KubernetesExternalShuffleBlockHandler( transportConf: TransportConf, - cleanerIntervalS: Long) + cleanerIntervals: Long, + indexCacheSize: String) extends ExternalShuffleBlockHandler(transportConf, null) with Logging { ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") - .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervalS, TimeUnit.SECONDS) + .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervals, TimeUnit.SECONDS) // Stores a map of app id to app state (timeout value and last heartbeat) private val connectedApps = new ConcurrentHashMap[String, AppState]() private val registeredExecutors = new ConcurrentHashMap[String, Map[String, ExecutorShuffleInfo]]() + private val indexCacheLoader = new CacheLoader[File, ShuffleIndexInformation]() { + override def load(file: File): ShuffleIndexInformation = new ShuffleIndexInformation(file) + } + private val shuffleIndexCache = CacheBuilder.newBuilder() + .maximumWeight(JavaUtils.byteStringAsBytes(indexCacheSize)) + .weigher(new Weigher[File, ShuffleIndexInformation]() { + override def weigh(file: File, indexInfo: ShuffleIndexInformation): Int = + indexInfo.getSize + }) + .build(indexCacheLoader) private val knownManagers = Array( "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") private final val shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile() + protected override def handleMessage( message: BlockTransferMessage, client: TransportClient, @@ -120,13 +132,23 @@ private[spark] class KubernetesExternalShuffleBlockHandler( case None => logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + s"address $address, appId '$appId').") - case OpenParam(appId, execId, shuffleId, mapId, partitionId) => - logInfo(s"Received open param from app $appId from $execId") - val file = getFile( - appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) - val fileInputStream = new FileInputStream(file) - val bytes = IOUtils.toByteArray(fileInputStream) - callback.onSuccess(ByteBuffer.wrap(bytes)) + } + case OpenParam(appId, execId, shuffleId, mapId, partitionId) => + logInfo(s"Received open param from app $appId from $execId") + val indexFile = getFile( + appId, execId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) + try { + val shuffleIndexInformation = shuffleIndexCache.get(indexFile) + val shuffleIndexRecord = shuffleIndexInformation.getIndex(partitionId) + val managedBuffer = new FileSegmentManagedBuffer( + transportConf, + getFile(appId, execId, shuffleId, mapId, + "data", FileWriterStreamCallback.FileType.DATA), + shuffleIndexRecord.getOffset, + shuffleIndexRecord.getLength) + callback.onSuccess(managedBuffer.nioByteBuffer()) + } catch { + case e: ExecutionException => logError(s"Unable to write index file $indexFile", e) } case _ => super.handleMessage(message, client, callback) } @@ -139,9 +161,14 @@ private[spark] class KubernetesExternalShuffleBlockHandler( header match { case UploadParam( appId, execId, shuffleId, mapId, partitionId) => + // TODO: Investigate whether we should use the partitionId for Index File creation logInfo(s"Received upload param from app $appId from $execId") getFileWriterStreamCallback( appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + case UploadIndexParam(appId, execId, shuffleId, mapId) => + logInfo(s"Received upload index param from app $appId from $execId") + getFileWriterStreamCallback( + appId, execId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) case _ => super.handleStream(header, client, callback) } @@ -197,6 +224,11 @@ private[spark] class KubernetesExternalShuffleBlockHandler( Some((u.appId, u.execId, u.shuffleId, u.mapId, u.partitionId)) } + private object UploadIndexParam { + def unapply(u: UploadShuffleIndexStream): Option[(String, String, Int, Int)] = + Some((u.appId, u.execId, u.shuffleId, u.mapId)) + } + private object RegisterExecutorParam { def unapply(e: RegisterExecutorWithExternal): Option[(String, String, String)] = Some((e.appId, e.execId, e.shuffleManager)) @@ -236,8 +268,9 @@ private[spark] class KubernetesExternalShuffleService( protected override def newShuffleBlockHandler( conf: TransportConf): ExternalShuffleBlockHandler = { - val cleanerIntervalS = this.conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) - new KubernetesExternalShuffleBlockHandler(conf, cleanerIntervalS) + val cleanerIntervals = this.conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) + val indexCacheSize = this.conf.get("spark.shuffle.service.index.cache.size", "100m") + new KubernetesExternalShuffleBlockHandler(conf, cleanerIntervals, indexCacheSize) } } From 109fbaa79f7541b19840ef4b9a599b6c7cd7aede Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 2 Jan 2019 17:15:07 -0500 Subject: [PATCH 07/30] fixing read issues --- .../shuffle/protocol/BlockTransferMessage.java | 2 +- .../protocol/UploadShuffleIndexStream.java | 4 +--- .../protocol/UploadShufflePartitionStream.java | 2 -- .../external/ExternalShufflePartitionReader.java | 16 +++++----------- .../external/ExternalShufflePartitionWriter.java | 4 ++-- .../external/ExternalShuffleWriteSupport.java | 5 ++++- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 1 + .../org/apache/spark/storage/BlockManager.scala | 3 +++ .../k8s/KubernetesExternalShuffleService.scala | 14 +++++++++----- 9 files changed, 26 insertions(+), 25 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index c38971231de7..b52a9f632e3b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -41,7 +41,7 @@ public abstract class BlockTransferMessage implements Encodable { public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), - UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9), REGISTER_EXECUTOR_WITH_EXTERNAL(19); + UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9), REGISTER_EXECUTOR_WITH_EXTERNAL(10); private final byte id; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java index 025459afeab3..04ff66a986a0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java @@ -25,9 +25,7 @@ import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** - * Upload shuffle partition request to the External Shuffle Service. - * This request should also include the driverHostPort for the sake of - * setting up a driver heartbeat to monitor heartbeat + * Upload shuffle index request to the External Shuffle Service. */ public class UploadShuffleIndexStream extends BlockTransferMessage { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java index a72d81b84339..ee5629e74526 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -26,8 +26,6 @@ /** * Upload shuffle partition request to the External Shuffle Service. - * This request should also include the driverHostPort for the sake of - * setting up a driver heartbeat to monitor heartbeat */ public class UploadShufflePartitionStream extends BlockTransferMessage { public final String appId; diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index 9d6d9cd1f432..a31f02cd4b54 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -2,15 +2,13 @@ import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.server.OneForOneStreamManager; -import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; -import org.apache.spark.network.shuffle.protocol.StreamHandle; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.util.ByteBufferInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.io.SequenceInputStream; @@ -49,15 +47,11 @@ public InputStream fetchPartition(int reduceId) { ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); try { - StreamCombiningCallback callback = new StreamCombiningCallback(); - StreamHandle streamHandle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response); - for (int i = 0; i < streamHandle.numChunks; i++) { - client.stream( - OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), - callback); + if (response.hasArray()) { + // use heap buffer; no array is created; only the reference is used + return new ByteArrayInputStream(response.array()); } - return callback.getCombinedInputStream(); + return new ByteBufferInputStream(response); } catch (Exception e) { logger.error("Encountered exception while trying to fetch blocks", e); throw new RuntimeException(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 61cda892dcdb..a486c7b36021 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -72,12 +72,12 @@ public void onFailure(Throwable e) { client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); totalLength += size; } catch (Exception e) { - logger.error("Encountered error while attempting to upload partition to ESS", e); client.close(); + logger.error("Encountered error while attempting to upload partition to ESS", e); throw new RuntimeException(e); } finally { - logger.info("Successfully sent partition to ESS"); client.close(); + logger.info("Successfully sent partition to ESS"); } return totalLength; } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 620931e9e1a7..92a605807324 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -78,8 +78,9 @@ public void onFailure(Throwable e) { logger.error("Encountered an error uploading index", e); } }; + final TransportClient client; try { - TransportClient client = clientFactory.createClient(hostname, port); + client = clientFactory.createClient(hostname, port); logger.info("Committing all partitions with a creation of an index file"); ByteBuffer streamHeader = new UploadShuffleIndexStream( appId, execId, shuffleId, mapId).toByteBuffer(); @@ -94,7 +95,9 @@ public void onFailure(Throwable e) { } client.uploadStream(new NioManagedBuffer(streamHeader), new NioManagedBuffer(byteBuffer), callback); + client.close(); } catch (Exception e) { + // Close client upon failure logger.error("Encountered error while creating transport client", e); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 5a882c08dd05..4e299034a893 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -564,6 +564,7 @@ private long[] mergeSpillsWithPluggableWriter( } } mapOutputWriter.commitAllPartitions(partitionLengths); + threwException = false; } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 233dd1d4bfde..daa52525b154 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -274,6 +274,9 @@ private[spark] class BlockManager( } shuffleServerId = if (externalk8sShuffleServiceEnabled) { + // TODO: Investigate better methods of load balancing + // note: might break if retry (as exec could write to one of the addresses + // it did not write to randomShuffleServiceAddress = Random.shuffle(remoteShuffleServiceAddress).head BlockManagerId(executorId, randomShuffleServiceAddress._1, randomShuffleServiceAddress._2) } else if (externalNonK8sShuffleService) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 684e550c1ce2..f77524c34855 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -24,8 +24,8 @@ import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} import java.util.function.BiFunction import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL @@ -36,7 +36,7 @@ import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver._ import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.network.util.{JavaUtils, TransportConf} -import org.apache.spark.util.ThreadUtils +import org.apache.spark.util.{ThreadUtils, Utils} /** * An RPC endpoint that receives registration requests from Spark drivers running on Kubernetes. @@ -69,8 +69,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( private val knownManagers = Array( "org.apache.spark.shuffle.sort.SortShuffleManager", "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") - private final val shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile() - + private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir") protected override def handleMessage( message: BlockTransferMessage, @@ -250,7 +249,12 @@ private[spark] class KubernetesExternalShuffleBlockHandler( connectedApps.remove(appId) applicationRemoved(appId, false) registeredExecutors.remove(appId) - // TODO: Write removal logic + try { + val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile + driverDir.delete() + } catch { + case e: Exception => logError("Unable to delete files", e) + } } } } From bce2ed0d471390871365d9394c8e034e62dbb56b Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 7 Jan 2019 10:49:35 -0800 Subject: [PATCH 08/30] investigating issue with correctness bug --- .../external/ExternalShuffleIndexWriter.java | 71 +++++++++++++++++ .../ExternalShuffleMapOutputWriter.java | 76 +++++++++++++++++++ .../ExternalShufflePartitionReader.java | 46 +---------- .../ExternalShufflePartitionWriter.java | 8 +- .../external/ExternalShuffleReadSupport.java | 3 +- .../external/ExternalShuffleWriteSupport.java | 66 +--------------- .../KubernetesExternalShuffleService.scala | 8 +- 7 files changed, 165 insertions(+), 113 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java new file mode 100644 index 000000000000..59148a806dfb --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java @@ -0,0 +1,71 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.LongBuffer; + +public class ExternalShuffleIndexWriter { + + private final TransportClient client; + private final String appId; + private final String execId; + private final int shuffleId; + private final int mapId; + + public ExternalShuffleIndexWriter( + TransportClient client, + String appId, + String execId, + int shuffleId, + int mapId){ + this.client = client; + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + private static final Logger logger = + LoggerFactory.getLogger(ExternalShuffleIndexWriter.class); + + public void write(long[] partitionLengths) { + RpcResponseCallback callback = new RpcResponseCallback() { + @Override + public void onSuccess(ByteBuffer response) { + logger.info("Successfully uploaded index"); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Encountered an error uploading index", e); + } + }; + try { + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + logger.info("Committing all partitions with a creation of an index file"); + logger.info("Partition Lengths: " + partitionLengths[0] + ": " + partitionLengths.length); + ByteBuffer streamHeader = new UploadShuffleIndexStream( + appId, execId, shuffleId, mapId).toByteBuffer(); + // Size includes first 0L offset + ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8)); + LongBuffer longBuffer = byteBuffer.asLongBuffer(); + Long offset = 0L; + longBuffer.put(offset); + for (Long length: partitionLengths) { + offset += length; + longBuffer.put(offset); + } + client.uploadStream(new NioManagedBuffer(streamHeader), + new NioManagedBuffer(byteBuffer), callback); + } catch (Exception e) { + client.close(); + logger.error("Encountered error while creating transport client", e); + } + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java new file mode 100644 index 000000000000..1924fe00ed21 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -0,0 +1,76 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private final TransportClientFactory clientFactory; + private final String hostname; + private final int port; + private final String appId; + private final String execId; + private final int shuffleId; + private final int mapId; + + public ExternalShuffleMapOutputWriter( + TransportClientFactory clientFactory, + String hostname, + int port, + String appId, + String execId, + int shuffleId, + int mapId) { + this.clientFactory = clientFactory; + this.hostname = hostname; + this.port = port; + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + private static final Logger logger = + LoggerFactory.getLogger(ExternalShuffleMapOutputWriter.class); + + @Override + public ShufflePartitionWriter newPartitionWriter(int partitionId) { + try { + TransportClient client = clientFactory.createUnmanagedClient(hostname, port); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + return new ExternalShufflePartitionWriter( + client, appId, execId, shuffleId, mapId, partitionId); + } catch (Exception e) { + clientFactory.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); + } + } + + @Override + public void commitAllPartitions(long[] partitionLengths) { + try { + TransportClient client = clientFactory.createUnmanagedClient(hostname, port); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + ExternalShuffleIndexWriter externalShuffleIndexWriter = + new ExternalShuffleIndexWriter( + client, appId, execId, shuffleId, mapId); + externalShuffleIndexWriter.write(partitionLengths); + } catch (Exception e) { + clientFactory.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); // what is standard practice here? + } + } + + @Override + public void abort(Exception exception) { + clientFactory.close(); + logger.error("Encountered error while" + + "attempting to add partitions to ESS", exception); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index a31f02cd4b54..bf9cf4c29ba9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -1,6 +1,5 @@ package org.apache.spark.shuffle.external; -import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; import org.apache.spark.shuffle.api.ShufflePartitionReader; @@ -9,11 +8,8 @@ import org.slf4j.LoggerFactory; import java.io.ByteArrayInputStream; -import java.io.IOException; import java.io.InputStream; -import java.io.SequenceInputStream; import java.nio.ByteBuffer; -import java.util.Vector; public class ExternalShufflePartitionReader implements ShufflePartitionReader { @@ -43,54 +39,20 @@ public ExternalShufflePartitionReader( public InputStream fetchPartition(int reduceId) { OpenShufflePartition openMessage = new OpenShufflePartition(appId, execId, shuffleId, mapId, reduceId); - ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); - try { +// logger.info("response is: " + response.toString() + " " + response.getDouble()); if (response.hasArray()) { // use heap buffer; no array is created; only the reference is used return new ByteArrayInputStream(response.array()); } return new ByteBufferInputStream(response); } catch (Exception e) { + this.client.close(); logger.error("Encountered exception while trying to fetch blocks", e); throw new RuntimeException(e); - } - } - - private class StreamCombiningCallback implements StreamCallback { - - public boolean failed; - private final Vector inputStreams; - - private StreamCombiningCallback() { - inputStreams = new Vector<>(); - failed = false; - } - - @Override - public void onData(String streamId, ByteBuffer buf) throws IOException { - inputStreams.add(new ByteBufferInputStream(buf)); - } - - @Override - public void onComplete(String streamId) throws IOException { - // do nothing - } - - @Override - public void onFailure(String streamId, Throwable cause) throws IOException { - failed = true; - for (InputStream stream : inputStreams) { - stream.close(); - } - } - - private SequenceInputStream getCombinedInputStream() { - if (failed) { - throw new RuntimeException("Stream chunk gathering failed"); - } - return new SequenceInputStream(inputStreams.elements()); + } finally { + this.client.close(); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index a486c7b36021..4342a8e8e6f9 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -43,9 +43,7 @@ public ExternalShufflePartitionWriter( } @Override - public OutputStream openPartitionStream() { - return partitionBuffer; - } + public OutputStream openPartitionStream() { return this.partitionBuffer; } @Override public long commitAndGetTotalLength() { @@ -61,9 +59,10 @@ public void onFailure(Throwable e) { } }; try { + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); ByteBuffer streamHeader = new UploadShufflePartitionStream( - this.appId, execId, shuffleId, mapId, + appId, execId, shuffleId, mapId, partitionId).toByteBuffer(); int size = partitionBuffer.size(); byte[] buf = partitionBuffer.toByteArray(); @@ -84,6 +83,7 @@ public void onFailure(Throwable e) { @Override public void abort(Exception failureReason) { + this.client.close(); logger.error("Encountered error while attempting" + "to upload partition to ESS", failureReason); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index 7951a5318816..7a07637bf480 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -55,7 +55,8 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in TransportClient client = clientFactory.createClient(hostname, port); return new ExternalShufflePartitionReader(client, appId, execId, shuffleId, mapId); } catch (Exception e) { - logger.error("Encountered error while creating transport client"); + clientFactory.close(); + logger.error("Encountered creating transport client for partition reader"); throw new RuntimeException(e); // what is standard practice here? } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 92a605807324..1ec5de9de891 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -2,24 +2,17 @@ import com.google.common.collect.Lists; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; import org.apache.spark.network.server.NoOpRpcHandler; -import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream; import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; -import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.nio.ByteBuffer; -import java.nio.LongBuffer; import java.util.List; public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { @@ -52,61 +45,8 @@ public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, in bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } TransportClientFactory clientFactory = context.createClientFactory(bootstraps); - return new ShuffleMapOutputWriter() { - @Override - public ShufflePartitionWriter newPartitionWriter(int partitionId) { - try { - TransportClient client = clientFactory.createClient(hostname, port); - return new ExternalShufflePartitionWriter( - client, appId, execId, shuffleId, mapId, partitionId); - } catch (Exception e) { - logger.error("Encountered error while creating transport client", e); - throw new RuntimeException(e); // what is standard practice here? - } - } - - @Override - public void commitAllPartitions(long[] partitionLengths) { - RpcResponseCallback callback = new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully uploaded index"); - } - - @Override - public void onFailure(Throwable e) { - logger.error("Encountered an error uploading index", e); - } - }; - final TransportClient client; - try { - client = clientFactory.createClient(hostname, port); - logger.info("Committing all partitions with a creation of an index file"); - ByteBuffer streamHeader = new UploadShuffleIndexStream( - appId, execId, shuffleId, mapId).toByteBuffer(); - // Size includes first 0L offset - ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8)); - LongBuffer longBuffer = byteBuffer.asLongBuffer(); - Long offset = 0L; - longBuffer.put(offset); - for (Long length: partitionLengths) { - offset += length; - longBuffer.put(offset); - } - client.uploadStream(new NioManagedBuffer(streamHeader), - new NioManagedBuffer(byteBuffer), callback); - client.close(); - } catch (Exception e) { - // Close client upon failure - logger.error("Encountered error while creating transport client", e); - } - } - - @Override - public void abort(Exception exception) { - logger.error("Encountered error while" + - "attempting to all partitions to ESS", exception); - } - }; + logger.info("Clientfactory: " + clientFactory.toString()); + return new ExternalShuffleMapOutputWriter( + clientFactory, hostname, port, appId, execId, shuffleId, mapId); } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index f77524c34855..0b64d7f6f907 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -19,13 +19,13 @@ package org.apache.spark.deploy.k8s import java.io.File import java.nio.ByteBuffer -import java.nio.file.{Files, Paths} +import java.nio.file.Paths import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} import java.util.function.BiFunction import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} - import scala.collection.JavaConverters._ + import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL @@ -48,6 +48,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( indexCacheSize: String) extends ExternalShuffleBlockHandler(transportConf, null) with Logging { + ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervals, TimeUnit.SECONDS) @@ -205,7 +206,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( s"App is not registered for remote shuffle (appId=$appId, execId=$execId)") } ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, - "shuffle_" + shuffleId + "_" + mapId + "_0." + extension) + s"shuffle_${shuffleId}_${mapId}_0.$extension") } /** An extractor object for matching BlockTransferMessages. */ @@ -251,6 +252,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( registeredExecutors.remove(appId) try { val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile + logInfo(s"Driver dir is: ${driverDir.getAbsolutePath}") driverDir.delete() } catch { case e: Exception => logError("Unable to delete files", e) From 28714b32d02f6341b3899d90792df0ec27571cea Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Mon, 7 Jan 2019 18:39:42 -0800 Subject: [PATCH 09/30] refactored executor specific logic and began fixing transport client issues --- .../shuffle/FileWriterStreamCallback.java | 17 ++-- .../k8s/KubernetesExternalShuffleClient.java | 37 ------- .../protocol/BlockTransferMessage.java | 3 +- .../protocol/OpenShufflePartition.java | 15 +-- .../RegisterExecutorWithExternal.java | 90 ----------------- .../protocol/UploadShuffleIndexStream.java | 14 +-- .../UploadShufflePartitionStream.java | 14 +-- .../external/ExternalShuffleDataIO.java | 9 +- .../external/ExternalShuffleIndexWriter.java | 30 ++++-- .../ExternalShuffleMapOutputWriter.java | 24 ++--- .../ExternalShufflePartitionReader.java | 36 ++++--- .../ExternalShufflePartitionWriter.java | 52 +++++++--- .../external/ExternalShuffleReadSupport.java | 18 ++-- .../external/ExternalShuffleWriteSupport.java | 8 +- .../apache/spark/storage/BlockManager.scala | 7 -- .../KubernetesExternalShuffleService.scala | 99 +++++-------------- 16 files changed, 144 insertions(+), 329 deletions(-) delete mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java index 6ca1292efb6b..16f45f4cc292 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -33,7 +33,7 @@ public String toString() { } } - private final ExternalShuffleBlockResolver.AppExecId fullExecId; + private final String appId; private final int shuffleId; private final int mapId; private final File file; @@ -41,12 +41,12 @@ public String toString() { private WritableByteChannel fileOutputChannel = null; public FileWriterStreamCallback( - ExternalShuffleBlockResolver.AppExecId fullExecId, + String appId, int shuffleId, int mapId, File file, FileWriterStreamCallback.FileType fileType) { - this.fullExecId = fullExecId; + this.appId = appId; this.shuffleId = shuffleId; this.mapId = mapId; this.file = file; @@ -55,7 +55,7 @@ public FileWriterStreamCallback( public void open() { logger.info( - "Opening {} for backup writing. File type: {}", file.getAbsolutePath(), fileType); + "Opening {} for remote writing. File type: {}", file.getAbsolutePath(), fileType); if (fileOutputChannel != null) { throw new IllegalStateException( String.format( @@ -101,9 +101,8 @@ public void open() { @Override public String getID() { - return String.format("%s-%s-%d-%d-%s", - fullExecId.appId, - fullExecId.execId, + return String.format("%s-%d-%d-%s", + appId, shuffleId, mapId, fileType); @@ -124,7 +123,7 @@ public void onComplete(String streamId) throws IOException { @Override public void onFailure(String streamId, Throwable cause) throws IOException { - logger.warn("Failed to back up shuffle file at {} (type: %s).", + logger.warn("Failed to write shuffle file at {} (type: %s).", file.getAbsolutePath(), fileType, cause); @@ -132,7 +131,7 @@ public void onFailure(String streamId, Throwable cause) throws IOException { // TODO delete parent dirs too if (!file.delete()) { logger.warn( - "Failed to delete incomplete backup shuffle file at %s (type: %s)", + "Failed to delete incomplete remote shuffle file at %s (type: %s)", file.getAbsolutePath(), fileType); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java index b145c0d5e8bd..06af3bd141fb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/k8s/KubernetesExternalShuffleClient.java @@ -24,7 +24,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.apache.spark.network.shuffle.protocol.RegisterExecutorWithExternal; import org.apache.spark.network.shuffle.protocol.ShuffleServiceHeartbeat; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -80,20 +79,6 @@ public void registerDriverWithShuffleService( client.sendRpc(registerDriver, new RegisterDriverCallback(client, heartbeatIntervalMs)); } - public void registerExecutorWithShuffleService( - String host, - int port, - String appId, - String execId, - String shuffleManager) throws IOException, InterruptedException { - checkInit(); - ByteBuffer registerExecutor = - new RegisterExecutorWithExternal(appId, execId, shuffleManager).toByteBuffer(); - logger.info("Registering with external shuffle service for " + appId + ":" + execId); - TransportClient client = clientFactory.createClient(host, port); - client.sendRpc(registerExecutor, new RegisterExecutorCallback(appId, execId)); - } - private class RegisterDriverCallback implements RpcResponseCallback { private final TransportClient client; private final long heartbeatIntervalMs; @@ -117,28 +102,6 @@ public void onFailure(Throwable e) { } } - private class RegisterExecutorCallback implements RpcResponseCallback { - private String appId; - private String execId; - - private RegisterExecutorCallback(String appId, String execId) { - this.appId = appId; - this.execId = execId; - } - - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully registered " + - appId + ":" + execId + " with external shuffle service."); - } - - @Override - public void onFailure(Throwable e) { - logger.warn("Unable to register " + - appId + ":" + execId + " with external shuffle service, " + e); - } - } - @Override public void close() { heartbeaterThread.shutdownNow(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index b52a9f632e3b..b2b0f3f9796c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -41,7 +41,7 @@ public abstract class BlockTransferMessage implements Encodable { public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), - UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9), REGISTER_EXECUTOR_WITH_EXTERNAL(10); + UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9); private final byte id; @@ -70,7 +70,6 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 7: return UploadShufflePartitionStream.decode(buf); case 8: return UploadShuffleIndexStream.decode(buf); case 9: return OpenShufflePartition.decode(buf); - case 10: return RegisterExecutorWithExternal.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java index 78f6f5dc5910..63d2387bd6d1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenShufflePartition.java @@ -9,15 +9,13 @@ public class OpenShufflePartition extends BlockTransferMessage { public final String appId; - public final String execId; public final int shuffleId; public final int mapId; public final int partitionId; public OpenShufflePartition( - String appId, String execId, int shuffleId, int mapId, int partitionId) { + String appId, int shuffleId, int mapId, int partitionId) { this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; @@ -28,7 +26,6 @@ public boolean equals(Object other) { if (other != null && other instanceof OpenShufflePartition) { OpenShufflePartition o = (OpenShufflePartition) other; return Objects.equal(appId, o.appId) - && execId == o.execId && shuffleId == o.shuffleId && mapId == o.mapId && partitionId == o.partitionId; @@ -43,14 +40,13 @@ protected Type type() { @Override public int hashCode() { - return Objects.hashCode(appId, execId, shuffleId, mapId, partitionId); + return Objects.hashCode(appId, shuffleId, mapId, partitionId); } @Override public String toString() { return Objects.toStringHelper(this) .add("appId", appId) - .add("execId", execId) .add("shuffleId", shuffleId) .add("mapId", mapId) .add("partitionId", partitionId) @@ -59,14 +55,12 @@ public String toString() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + - Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); buf.writeInt(mapId); buf.writeInt(partitionId); @@ -74,10 +68,9 @@ public void encode(ByteBuf buf) { public static OpenShufflePartition decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); int mapId = buf.readInt(); int partitionId = buf.readInt(); - return new OpenShufflePartition(appId, execId, shuffleId, mapId, partitionId); + return new OpenShufflePartition(appId, shuffleId, mapId, partitionId); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java deleted file mode 100644 index 39bfb95b4af3..000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterExecutorWithExternal.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.network.shuffle.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -// Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - -public class RegisterExecutorWithExternal extends BlockTransferMessage { - - public final String appId; - public final String execId; - public final String shuffleManager; - - public RegisterExecutorWithExternal( - String appId, String execId, String shuffleManager) { - this.appId = appId; - this.execId = execId; - this.shuffleManager = shuffleManager; - } - - @Override - protected Type type() { - return Type.REGISTER_EXECUTOR_WITH_EXTERNAL; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + Encoders.Strings.encodedLength(shuffleManager); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - Encoders.Strings.encode(buf, shuffleManager); - } - - @Override - public boolean equals(Object other) { - if (other instanceof RegisterExecutorWithExternal) { - RegisterExecutorWithExternal o = (RegisterExecutorWithExternal) other; - return Objects.equal(appId, o.appId) - && Objects.equal(execId, o.execId) - && Objects.equal(shuffleManager, o.shuffleManager); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hashCode(appId, execId, shuffleManager); - } - - @Override - public String toString() { - return Objects.toStringHelper(RegisterExecutorWithExternal.class) - .add("appId", appId) - .add("execId", execId) - .add("shuffleManager", shuffleManager) - .toString(); - } - - public static RegisterExecutorWithExternal decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); - String shuffleManager = Encoders.Strings.decode(buf); - return new RegisterExecutorWithExternal(appId, execId, shuffleManager); - } -} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java index 04ff66a986a0..ffa7ee36881c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java @@ -29,17 +29,14 @@ */ public class UploadShuffleIndexStream extends BlockTransferMessage { public final String appId; - public final String execId; public final int shuffleId; public final int mapId; public UploadShuffleIndexStream( String appId, - String execId, int shuffleId, int mapId) { this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; } @@ -49,7 +46,6 @@ public boolean equals(Object other) { if (other != null && other instanceof UploadShufflePartitionStream) { UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; return Objects.equal(appId, o.appId) - && execId == o.execId && shuffleId == o.shuffleId && mapId == o.mapId; } @@ -63,14 +59,13 @@ protected Type type() { @Override public int hashCode() { - return Objects.hashCode(appId, execId, shuffleId, mapId); + return Objects.hashCode(appId, shuffleId, mapId); } @Override public String toString() { return Objects.toStringHelper(this) .add("appId", appId) - .add("execId", execId) .add("shuffleId", shuffleId) .add("mapId", mapId) .toString(); @@ -78,23 +73,20 @@ public String toString() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + - Encoders.Strings.encodedLength(execId) + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); buf.writeInt(mapId); } public static UploadShuffleIndexStream decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); int mapId = buf.readInt(); - return new UploadShuffleIndexStream(appId, execId, shuffleId, mapId); + return new UploadShuffleIndexStream(appId, shuffleId, mapId); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java index ee5629e74526..f0506cc08feb 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -29,19 +29,16 @@ */ public class UploadShufflePartitionStream extends BlockTransferMessage { public final String appId; - public final String execId; public final int shuffleId; public final int mapId; public final int partitionId; public UploadShufflePartitionStream( String appId, - String execId, int shuffleId, int mapId, int partitionId) { this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; @@ -52,7 +49,6 @@ public boolean equals(Object other) { if (other != null && other instanceof UploadShufflePartitionStream) { UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; return Objects.equal(appId, o.appId) - && execId == o.execId && shuffleId == o.shuffleId && mapId == o.mapId && partitionId == o.partitionId; @@ -67,14 +63,13 @@ protected Type type() { @Override public int hashCode() { - return Objects.hashCode(appId, execId, shuffleId, mapId, partitionId); + return Objects.hashCode(appId, shuffleId, mapId, partitionId); } @Override public String toString() { return Objects.toStringHelper(this) .add("appId", appId) - .add("execId", execId) .add("shuffleId", shuffleId) .add("mapId", mapId) .toString(); @@ -82,14 +77,12 @@ public String toString() { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + - Encoders.Strings.encodedLength(execId) + 4 + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); buf.writeInt(shuffleId); buf.writeInt(mapId); buf.writeInt(partitionId); @@ -97,10 +90,9 @@ public void encode(ByteBuf buf) { public static UploadShufflePartitionStream decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - String execId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); int mapId = buf.readInt(); int partitionId = buf.readInt(); - return new UploadShufflePartitionStream(appId, execId, shuffleId, mapId, partitionId); + return new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 2500711b69fd..da35ac76f343 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -23,18 +23,15 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private final SecurityManager securityManager; private final String hostname; private final int port; - private final String execId; public ExternalShuffleDataIO( SparkConf sparkConf) { this.sparkConf = sparkConf; - this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 2); + this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1); this.securityManager = sparkEnv.securityManager(); this.hostname = blockManager.getRandomShuffleHost(); this.port = blockManager.getRandomShufflePort(); - - this.execId = blockManager.shuffleServerId().executorId(); } @Override @@ -46,13 +43,13 @@ public void initialize() { public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port, execId); + securityManager, hostname, port); } @Override public ShuffleWriteSupport writeSupport() { return new ExternalShuffleWriteSupport( conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port, execId); + securityManager, hostname, port); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java index 59148a806dfb..6983d061289f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java @@ -3,6 +3,7 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -12,21 +13,24 @@ public class ExternalShuffleIndexWriter { - private final TransportClient client; + private final TransportClientFactory clientFactory; + private final String hostName; + private final int port; private final String appId; - private final String execId; private final int shuffleId; private final int mapId; public ExternalShuffleIndexWriter( - TransportClient client, + TransportClientFactory clientFactory, + String hostName, + int port, String appId, - String execId, int shuffleId, int mapId){ - this.client = client; + this.clientFactory = clientFactory; + this.hostName = hostName; + this.port = port; this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; } @@ -46,12 +50,13 @@ public void onFailure(Throwable e) { logger.error("Encountered an error uploading index", e); } }; + TransportClient client = null; try { - logger.info("clientid: " + client.getClientId() + " " + client.isActive()); logger.info("Committing all partitions with a creation of an index file"); - logger.info("Partition Lengths: " + partitionLengths[0] + ": " + partitionLengths.length); + logger.info("Partition Lengths: " + partitionLengths.length + ": " + + partitionLengths[0] + "," + partitionLengths[1]); ByteBuffer streamHeader = new UploadShuffleIndexStream( - appId, execId, shuffleId, mapId).toByteBuffer(); + appId, shuffleId, mapId).toByteBuffer(); // Size includes first 0L offset ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8)); LongBuffer longBuffer = byteBuffer.asLongBuffer(); @@ -61,11 +66,18 @@ public void onFailure(Throwable e) { offset += length; longBuffer.put(offset); } + client = clientFactory.createUnmanagedClient(hostName, port); + client.setClientId(String.format("index-%s-%d-%d", appId, shuffleId, mapId)); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); client.uploadStream(new NioManagedBuffer(streamHeader), new NioManagedBuffer(byteBuffer), callback); } catch (Exception e) { client.close(); logger.error("Encountered error while creating transport client", e); + } finally { + if (client != null) { + client.close(); + } } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 1924fe00ed21..786a56d46482 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -1,6 +1,5 @@ package org.apache.spark.shuffle.external; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; @@ -10,26 +9,23 @@ public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final TransportClientFactory clientFactory; - private final String hostname; + private final String hostName; private final int port; private final String appId; - private final String execId; private final int shuffleId; private final int mapId; public ExternalShuffleMapOutputWriter( TransportClientFactory clientFactory, - String hostname, + String hostName, int port, String appId, - String execId, int shuffleId, int mapId) { this.clientFactory = clientFactory; - this.hostname = hostname; + this.hostName = hostName; this.port = port; this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; } @@ -40,10 +36,8 @@ public ExternalShuffleMapOutputWriter( @Override public ShufflePartitionWriter newPartitionWriter(int partitionId) { try { - TransportClient client = clientFactory.createUnmanagedClient(hostname, port); - logger.info("clientid: " + client.getClientId() + " " + client.isActive()); - return new ExternalShufflePartitionWriter( - client, appId, execId, shuffleId, mapId, partitionId); + return new ExternalShufflePartitionWriter(clientFactory, + hostName, port, appId, shuffleId, mapId, partitionId); } catch (Exception e) { clientFactory.close(); logger.error("Encountered error while creating transport client", e); @@ -54,15 +48,13 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) { @Override public void commitAllPartitions(long[] partitionLengths) { try { - TransportClient client = clientFactory.createUnmanagedClient(hostname, port); - logger.info("clientid: " + client.getClientId() + " " + client.isActive()); ExternalShuffleIndexWriter externalShuffleIndexWriter = - new ExternalShuffleIndexWriter( - client, appId, execId, shuffleId, mapId); + new ExternalShuffleIndexWriter(clientFactory, + hostName, port, appId, shuffleId, mapId); externalShuffleIndexWriter.write(partitionLengths); } catch (Exception e) { clientFactory.close(); - logger.error("Encountered error while creating transport client", e); + logger.error("Encountered error writing index file", e); throw new RuntimeException(e); // what is standard practice here? } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index bf9cf4c29ba9..a83016e72fb8 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -1,6 +1,7 @@ package org.apache.spark.shuffle.external; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.util.ByteBufferInputStream; @@ -16,21 +17,24 @@ public class ExternalShufflePartitionReader implements ShufflePartitionReader { private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionReader.class); - private final TransportClient client; + private final TransportClientFactory clientFactory; + private final String hostName; + private final int port; private final String appId; - private final String execId; private final int shuffleId; private final int mapId; public ExternalShufflePartitionReader( - TransportClient client, + TransportClientFactory clientFactory, + String hostName, + int port, String appId, - String execId, int shuffleId, int mapId) { - this.client = client; + this.clientFactory = clientFactory; + this.hostName = hostName; + this.port = port; this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; } @@ -38,21 +42,31 @@ public ExternalShufflePartitionReader( @Override public InputStream fetchPartition(int reduceId) { OpenShufflePartition openMessage = - new OpenShufflePartition(appId, execId, shuffleId, mapId, reduceId); - ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); + new OpenShufflePartition(appId, shuffleId, mapId, reduceId); + TransportClient client = null; try { -// logger.info("response is: " + response.toString() + " " + response.getDouble()); + client = clientFactory.createUnmanagedClient(hostName, port); + client.setClientId(String.format( + "read-%s-%d-%d-%d", appId, shuffleId, mapId, reduceId)); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); + logger.info("response is: " + response.toString() + + " " + response.array() + " " + response.hasArray()); if (response.hasArray()) { // use heap buffer; no array is created; only the reference is used return new ByteArrayInputStream(response.array()); } return new ByteBufferInputStream(response); } catch (Exception e) { - this.client.close(); + if (client != null) { + client.close(); + } logger.error("Encountered exception while trying to fetch blocks", e); throw new RuntimeException(e); } finally { - this.client.close(); + if (client != null) { + client.close(); + } } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 4342a8e8e6f9..75608d3bc3ee 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -4,6 +4,7 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.slf4j.Logger; @@ -17,9 +18,10 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); - private final TransportClient client; + private final TransportClientFactory clientFactory; + private final String hostName; + private final int port; private final String appId; - private final String execId; private final int shuffleId; private final int mapId; private final int partitionId; @@ -28,22 +30,24 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private final ByteArrayOutputStream partitionBuffer = new ByteArrayOutputStream(); public ExternalShufflePartitionWriter( - TransportClient client, + TransportClientFactory clientFactory, + String hostName, + int port, String appId, - String execId, int shuffleId, int mapId, int partitionId) { - this.client = client; + this.clientFactory = clientFactory; + this.hostName = hostName; + this.port = port; this.appId = appId; - this.execId = execId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; } @Override - public OutputStream openPartitionStream() { return this.partitionBuffer; } + public OutputStream openPartitionStream() { return partitionBuffer; } @Override public long commitAndGetTotalLength() { @@ -58,24 +62,35 @@ public void onFailure(Throwable e) { logger.error("Encountered an error uploading partition", e); } }; + TransportClient client = null; try { - logger.info("clientid: " + client.getClientId() + " " + client.isActive()); - ByteBuffer streamHeader = - new UploadShufflePartitionStream( - appId, execId, shuffleId, mapId, - partitionId).toByteBuffer(); + ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, + partitionId).toByteBuffer(); int size = partitionBuffer.size(); + partitionBuffer.flush(); byte[] buf = partitionBuffer.toByteArray(); - ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); + client = clientFactory.createUnmanagedClient(hostName, port); + client.setClientId(String.format("data-%s-%d-%d-%d", + appId, shuffleId, mapId, partitionId)); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); totalLength += size; } catch (Exception e) { - client.close(); + if (client != null) { + client.close(); + } logger.error("Encountered error while attempting to upload partition to ESS", e); throw new RuntimeException(e); } finally { - client.close(); + if (client != null) { + client.close(); + } + try { + partitionBuffer.close(); + } catch(Exception e) { + logger.error("Failed to close streams", e); + } logger.info("Successfully sent partition to ESS"); } return totalLength; @@ -83,7 +98,12 @@ public void onFailure(Throwable e) { @Override public void abort(Exception failureReason) { - this.client.close(); + clientFactory.close(); + try { + this.partitionBuffer.close(); + } catch(IOException e) { + logger.error("Failed to close streams after failing to upload partition", e); + } logger.error("Encountered error while attempting" + "to upload partition to ESS", failureReason); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index 7a07637bf480..ddff937d47c2 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -2,7 +2,6 @@ import com.google.common.collect.Lists; import org.apache.spark.network.TransportContext; -import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; @@ -23,37 +22,34 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; - private final String hostname; + private final String hostName; private final int port; - private final String execId; public ExternalShuffleReadSupport( TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostname, - int port, - String execId) { + String hostName, + int port) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; - this.hostname = hostname; + this.hostName = hostName; this.port = port; - this.execId = execId; } @Override public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) { // TODO combine this into a function with ExternalShuffleWriteSupport - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false); List bootstraps = Lists.newArrayList(); if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } TransportClientFactory clientFactory = context.createClientFactory(bootstraps); try { - TransportClient client = clientFactory.createClient(hostname, port); - return new ExternalShufflePartitionReader(client, appId, execId, shuffleId, mapId); + return new ExternalShufflePartitionReader(clientFactory, + hostName, port, appId, shuffleId, mapId); } catch (Exception e) { clientFactory.close(); logger.error("Encountered creating transport client for partition reader"); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 1ec5de9de891..4754c58f136b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -24,22 +24,20 @@ public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { private final SecretKeyHolder secretKeyHolder; private final String hostname; private final int port; - private final String execId; public ExternalShuffleWriteSupport( TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostname, int port, String execId) { + String hostname, int port) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; this.hostname = hostname; this.port = port; - this.execId = execId; } @Override public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false); List bootstraps = Lists.newArrayList(); if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); @@ -47,6 +45,6 @@ public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, in TransportClientFactory clientFactory = context.createClientFactory(bootstraps); logger.info("Clientfactory: " + clientFactory.toString()); return new ExternalShuffleMapOutputWriter( - clientFactory, hostname, port, appId, execId, shuffleId, mapId); + clientFactory, hostname, port, appId, shuffleId, mapId); } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index daa52525b154..d5bb04b13e9b 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -296,13 +296,6 @@ private[spark] class BlockManager( s"${conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) } - } else if (externalk8sShuffleServiceEnabled && !blockManagerId.isDriver) { - remoteShuffleServiceAddress.foreach { ssId => - shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] - .registerExecutorWithShuffleService( - ssId._1, ssId._2, appId, - shuffleServerId.executorId, shuffleManager.getClass.getName) - } } else if (externalNonK8sShuffleService && !blockManagerId.isDriver) { // Register Executors' configuration with the local shuffle service, if one should exist. registerWithExternalShuffleServer() diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 0b64d7f6f907..74cc43999543 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -21,7 +21,6 @@ import java.io.File import java.nio.ByteBuffer import java.nio.file.Paths import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} -import java.util.function.BiFunction import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} import scala.collection.JavaConverters._ @@ -54,8 +53,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( // Stores a map of app id to app state (timeout value and last heartbeat) private val connectedApps = new ConcurrentHashMap[String, AppState]() - private val registeredExecutors = - new ConcurrentHashMap[String, Map[String, ExecutorShuffleInfo]]() private val indexCacheLoader = new CacheLoader[File, ShuffleIndexInformation]() { override def load(file: File): ShuffleIndexInformation = new ShuffleIndexInformation(file) } @@ -67,9 +64,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( }) .build(indexCacheLoader) - private val knownManagers = Array( - "org.apache.spark.shuffle.sort.SortShuffleManager", - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager") + private val knownManagers = Array("org.apache.spark.shuffle.sort.SortShuffleManager") private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir") protected override def handleMessage( @@ -77,34 +72,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( client: TransportClient, callback: RpcResponseCallback): Unit = { message match { - case RegisterExecutorParam(appId, execId, shuffleManager) => - val fullId = new AppExecId(appId, execId) - if (registeredExecutors.containsKey(fullId)) { - throw new UnsupportedOperationException(s"Executor $fullId cannot be registered twice") - } - val executorDir = Paths.get(shuffleDir.getAbsolutePath, appId, execId).toFile - if (!executorDir.mkdir()) { - throw new RuntimeException(s"Failed to create dir ${executorDir.getAbsolutePath}") - } - if (!knownManagers.contains(shuffleManager)) { - throw new UnsupportedOperationException(s"Unsupported shuffle manager of exec: ${fullId}") - } - val executorShuffleInfo = new ExecutorShuffleInfo( - Array(executorDir.getAbsolutePath), 1, shuffleManager) - val execMap = Map(execId -> executorShuffleInfo) - registeredExecutors.merge(appId, execMap, - new BiFunction[ - Map[String, ExecutorShuffleInfo], - Map[String, ExecutorShuffleInfo], - Map[String, ExecutorShuffleInfo]]() { - override def apply( - t: Map[String, ExecutorShuffleInfo], u: Map[String, ExecutorShuffleInfo]): - Map[String, ExecutorShuffleInfo] = { - t ++ u - } - }) - logInfo(s"Registering executor ${fullId} with ${executorShuffleInfo}") - case RegisterDriverParam(appId, appState) => val address = client.getSocketAddress val timeout = appState.heartbeatTimeout @@ -119,7 +86,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( throw new RuntimeException(s"Failed to create dir ${driverDir.getAbsolutePath}") } connectedApps.put(appId, appState) - registeredExecutors.put(appId, Map[String, ExecutorShuffleInfo]()) callback.onSuccess(ByteBuffer.allocate(0)) case Heartbeat(appId) => @@ -133,17 +99,15 @@ private[spark] class KubernetesExternalShuffleBlockHandler( logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + s"address $address, appId '$appId').") } - case OpenParam(appId, execId, shuffleId, mapId, partitionId) => - logInfo(s"Received open param from app $appId from $execId") - val indexFile = getFile( - appId, execId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) + case OpenParam(appId, shuffleId, mapId, partitionId) => + logInfo(s"Received open param from app $appId") + val indexFile = getFile(appId, shuffleId, mapId, "index") try { val shuffleIndexInformation = shuffleIndexCache.get(indexFile) val shuffleIndexRecord = shuffleIndexInformation.getIndex(partitionId) val managedBuffer = new FileSegmentManagedBuffer( transportConf, - getFile(appId, execId, shuffleId, mapId, - "data", FileWriterStreamCallback.FileType.DATA), + getFile(appId, shuffleId, mapId, "data"), shuffleIndexRecord.getOffset, shuffleIndexRecord.getLength) callback.onSuccess(managedBuffer.nioByteBuffer()) @@ -160,15 +124,15 @@ private[spark] class KubernetesExternalShuffleBlockHandler( callback: RpcResponseCallback): StreamCallbackWithID = { header match { case UploadParam( - appId, execId, shuffleId, mapId, partitionId) => + appId, shuffleId, mapId, partitionId) => // TODO: Investigate whether we should use the partitionId for Index File creation - logInfo(s"Received upload param from app $appId from $execId") + logInfo(s"Received upload param from app $appId") getFileWriterStreamCallback( - appId, execId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) - case UploadIndexParam(appId, execId, shuffleId, mapId) => - logInfo(s"Received upload index param from app $appId from $execId") + appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + case UploadIndexParam(appId, shuffleId, mapId) => + logInfo(s"Received upload index param from app $appId") getFileWriterStreamCallback( - appId, execId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) + appId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) case _ => super.handleStream(header, client, callback) } @@ -176,37 +140,24 @@ private[spark] class KubernetesExternalShuffleBlockHandler( private def getFileWriterStreamCallback( appId: String, - execId: String, shuffleId: Int, mapId: Int, extension: String, fileType: FileWriterStreamCallback.FileType): StreamCallbackWithID = { - val file = getFile(appId, execId, shuffleId, mapId, extension, fileType) + val file = getFile(appId, shuffleId, mapId, extension) val streamCallback = - new FileWriterStreamCallback(new AppExecId(appId, execId), shuffleId, mapId, file, fileType) + new FileWriterStreamCallback(appId, shuffleId, mapId, file, fileType) streamCallback.open() streamCallback } private def getFile( appId: String, - execId: String, shuffleId: Int, mapId: Int, - extension: String, - fileType: FileWriterStreamCallback.FileType): File = { - val execMap = registeredExecutors.get(appId) - if (execMap == null) { - throw new RuntimeException( - s"appId=$appId is not registered for remote shuffle") - } - val executor = execMap(execId) - if (executor == null) { - throw new RuntimeException( - s"App is not registered for remote shuffle (appId=$appId, execId=$execId)") - } - ExternalShuffleBlockResolver.getFile(executor.localDirs, executor.subDirsPerLocalDir, - s"shuffle_${shuffleId}_${mapId}_0.$extension") + extension: String): File = { + Paths.get(shuffleDir.getAbsolutePath, appId, + s"shuffle_${shuffleId}_${mapId}_0.$extension").toFile } /** An extractor object for matching BlockTransferMessages. */ @@ -220,23 +171,18 @@ private[spark] class KubernetesExternalShuffleBlockHandler( } private object UploadParam { - def unapply(u: UploadShufflePartitionStream): Option[(String, String, Int, Int, Int)] = - Some((u.appId, u.execId, u.shuffleId, u.mapId, u.partitionId)) + def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId, u.partitionId)) } private object UploadIndexParam { - def unapply(u: UploadShuffleIndexStream): Option[(String, String, Int, Int)] = - Some((u.appId, u.execId, u.shuffleId, u.mapId)) - } - - private object RegisterExecutorParam { - def unapply(e: RegisterExecutorWithExternal): Option[(String, String, String)] = - Some((e.appId, e.execId, e.shuffleManager)) + def unapply(u: UploadShuffleIndexStream): Option[(String, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId)) } private object OpenParam { - def unapply(o: OpenShufflePartition): Option[(String, String, Int, Int, Int)] = - Some((o.appId, o.execId, o.shuffleId, o.mapId, o.partitionId)) + def unapply(o: OpenShufflePartition): Option[(String, Int, Int, Int)] = + Some((o.appId, o.shuffleId, o.mapId, o.partitionId)) } private class AppState(val heartbeatTimeout: Long, @volatile var lastHeartbeat: Long) @@ -249,7 +195,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( logInfo(s"Application $appId timed out. Removing shuffle files.") connectedApps.remove(appId) applicationRemoved(appId, false) - registeredExecutors.remove(appId) try { val driverDir = Paths.get(shuffleDir.getAbsolutePath, appId).toFile logInfo(s"Driver dir is: ${driverDir.getAbsolutePath}") From d598e003962f92910233449c631d0ec7b9e73e7b Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Tue, 8 Jan 2019 13:04:45 -0800 Subject: [PATCH 10/30] remove client issues --- .../shuffle/FileWriterStreamCallback.java | 2 ++ .../external/ExternalShuffleIndexWriter.java | 4 ---- .../external/ExternalShufflePartitionReader.java | 4 ---- .../external/ExternalShufflePartitionWriter.java | 16 ++++------------ .../k8s/KubernetesExternalShuffleService.scala | 1 - 5 files changed, 6 insertions(+), 21 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java index 16f45f4cc292..1f44ae8b3c78 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -118,6 +118,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { + logger.info( + "Finished writing {}. File type: {}", file.getAbsolutePath(), fileType); fileOutputChannel.close(); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java index 6983d061289f..946060785596 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java @@ -74,10 +74,6 @@ public void onFailure(Throwable e) { } catch (Exception e) { client.close(); logger.error("Encountered error while creating transport client", e); - } finally { - if (client != null) { - client.close(); - } } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index a83016e72fb8..c00761e74efc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -63,10 +63,6 @@ public InputStream fetchPartition(int reduceId) { } logger.error("Encountered exception while trying to fetch blocks", e); throw new RuntimeException(e); - } finally { - if (client != null) { - client.close(); - } } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 75608d3bc3ee..100fcc3d2a62 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -53,8 +53,7 @@ public ExternalShufflePartitionWriter( public long commitAndGetTotalLength() { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully uploaded partition"); + public void onSuccess(ByteBuffer response) { logger.info("Successfully uploaded partition"); } @Override @@ -66,9 +65,8 @@ public void onFailure(Throwable e) { try { ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId).toByteBuffer(); - int size = partitionBuffer.size(); - partitionBuffer.flush(); byte[] buf = partitionBuffer.toByteArray(); + int size = buf.length; ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); client = clientFactory.createUnmanagedClient(hostName, port); client.setClientId(String.format("data-%s-%d-%d-%d", @@ -76,6 +74,8 @@ public void onFailure(Throwable e) { logger.info("clientid: " + client.getClientId() + " " + client.isActive()); client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); totalLength += size; + logger.info("Partition Length: " + totalLength); + logger.info("Size: " + size); } catch (Exception e) { if (client != null) { client.close(); @@ -83,14 +83,6 @@ public void onFailure(Throwable e) { logger.error("Encountered error while attempting to upload partition to ESS", e); throw new RuntimeException(e); } finally { - if (client != null) { - client.close(); - } - try { - partitionBuffer.close(); - } catch(Exception e) { - logger.error("Failed to close streams", e); - } logger.info("Successfully sent partition to ESS"); } return totalLength; diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 74cc43999543..b90918651b96 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -64,7 +64,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( }) .build(indexCacheLoader) - private val knownManagers = Array("org.apache.spark.shuffle.sort.SortShuffleManager") private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir") protected override def handleMessage( From 90f3804d773ad6b5cc7f5f2f3640e53bb6c08d4e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Tue, 8 Jan 2019 16:17:39 -0800 Subject: [PATCH 11/30] added hashcode --- .../shuffle/external/ExternalShufflePartitionReader.java | 2 ++ .../shuffle/external/ExternalShufflePartitionWriter.java | 2 ++ .../org/apache/spark/examples/GroupByShuffleTest.scala | 9 +++++++++ 3 files changed, 13 insertions(+) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index c00761e74efc..08d079aadd9a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -11,6 +11,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.ByteBuffer; +import java.util.Arrays; public class ExternalShufflePartitionReader implements ShufflePartitionReader { @@ -53,6 +54,7 @@ public InputStream fetchPartition(int reduceId) { logger.info("response is: " + response.toString() + " " + response.array() + " " + response.hasArray()); if (response.hasArray()) { + logger.info("response hashcode: " + Arrays.hashCode(response.array())); // use heap buffer; no array is created; only the reference is used return new ByteArrayInputStream(response.array()); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 100fcc3d2a62..30930f855882 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -12,6 +12,7 @@ import java.io.*; import java.nio.ByteBuffer; +import java.util.Arrays; public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { @@ -72,6 +73,7 @@ public void onFailure(Throwable e) { client.setClientId(String.format("data-%s-%d-%d-%d", appId, shuffleId, mapId, partitionId)); logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + logger.info("THE BUFFER HASH CODE IS: " + Arrays.hashCode(buf)); client.uploadStream(new NioManagedBuffer(streamHeader), managedBuffer, callback); totalLength += size; logger.info("Partition Length: " + totalLength); diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala index 9d056a9f6f7b..0ce0d3bec6cd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -42,6 +42,15 @@ object GroupByShuffleTest { println(wordCountsWithGroup.mkString(",")) + val wordPairsRDD2 = spark.sparkContext.parallelize(words, 1).map(word => (word, 1)) + + val wordCountsWithGroup2 = wordPairsRDD2 + .groupByKey() + .map(t => (t._1, t._2.sum)) + .collect() + + println(wordCountsWithGroup2.mkString(",")) + spark.stop() } } From 7f307519785032a8df4a0641d666a8178c37e330 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 9 Jan 2019 11:25:01 -0800 Subject: [PATCH 12/30] small changes to replica-based shuffle service implementation --- .../ExternalShufflePartitionReader.java | 17 +++++++++++++++-- .../org/apache/spark/storage/BlockManager.scala | 5 ++--- .../storage/ShuffleBlockFetcherIterator.scala | 1 - .../k8s/KubernetesExternalShuffleService.scala | 2 +- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index 08d079aadd9a..adcb8c435fad 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -4,15 +4,20 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; import org.apache.spark.shuffle.api.ShufflePartitionReader; +import org.apache.spark.storage.ShuffleBlockFetcherIterator; import org.apache.spark.util.ByteBufferInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Arrays; +import static org.apache.hadoop.hive.ql.exec.MapredContext.close; + public class ExternalShufflePartitionReader implements ShufflePartitionReader { private static final Logger logger = @@ -55,10 +60,18 @@ public InputStream fetchPartition(int reduceId) { " " + response.array() + " " + response.hasArray()); if (response.hasArray()) { logger.info("response hashcode: " + Arrays.hashCode(response.array())); + ByteArrayInputStream responseStream = + new ByteArrayInputStream(response.array()); + logger.info(String.format( + "Stream info %d %d", + responseStream.available(), + responseStream.read())); // use heap buffer; no array is created; only the reference is used - return new ByteArrayInputStream(response.array()); + return new DataInputStream(responseStream); } - return new ByteBufferInputStream(response); + ByteBufferInputStream responseStream = + new ByteBufferInputStream(response); + return new DataInputStream(responseStream); } catch (Exception e) { if (client != null) { client.close(); diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index d5bb04b13e9b..fb1ed02c857a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -275,9 +275,8 @@ private[spark] class BlockManager( shuffleServerId = if (externalk8sShuffleServiceEnabled) { // TODO: Investigate better methods of load balancing - // note: might break if retry (as exec could write to one of the addresses - // it did not write to - randomShuffleServiceAddress = Random.shuffle(remoteShuffleServiceAddress).head + // note: might break if re-initialized + randomShuffleServiceAddress = remoteShuffleServiceAddress.head BlockManagerId(executorId, randomShuffleServiceAddress._1, randomShuffleServiceAddress._2) } else if (externalNonK8sShuffleService) { logInfo(s"external shuffle service port = $externalShuffleServicePort") diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 86f7c08eddcb..cf8e4793c448 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -437,7 +437,6 @@ final class ShuffleBlockFetcherIterator( s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" throwFetchFailedException(blockId, address, new IOException(msg)) } - val in = try { buf.createInputStream() } catch { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index b90918651b96..5f943f15c073 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -31,8 +31,8 @@ import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEA import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.FileSegmentManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} +import org.apache.spark.network.server.OneForOneStreamManager import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver._ import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.network.util.{JavaUtils, TransportConf} import org.apache.spark.util.{ThreadUtils, Utils} From cffc20ce2a6f60a767b2604260a00bb5983fe0c8 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 9 Jan 2019 15:56:04 -0800 Subject: [PATCH 13/30] solved read issue in terms of deserialization --- .../ExternalShufflePartitionReader.java | 27 ++++++------------- .../ExternalShufflePartitionWriter.java | 3 ++- .../shuffle/BlockStoreShuffleReader.scala | 6 +++-- .../KubernetesExternalShuffleService.scala | 1 - 4 files changed, 14 insertions(+), 23 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index adcb8c435fad..ee363f8bb41b 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -4,20 +4,14 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; import org.apache.spark.shuffle.api.ShufflePartitionReader; -import org.apache.spark.storage.ShuffleBlockFetcherIterator; import org.apache.spark.util.ByteBufferInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.ByteArrayInputStream; -import java.io.DataInputStream; -import java.io.IOException; -import java.io.InputStream; +import java.io.*; import java.nio.ByteBuffer; import java.util.Arrays; -import static org.apache.hadoop.hive.ql.exec.MapredContext.close; - public class ExternalShufflePartitionReader implements ShufflePartitionReader { private static final Logger logger = @@ -52,26 +46,21 @@ public InputStream fetchPartition(int reduceId) { TransportClient client = null; try { client = clientFactory.createUnmanagedClient(hostName, port); - client.setClientId(String.format( - "read-%s-%d-%d-%d", appId, shuffleId, mapId, reduceId)); + String requestID = String.format( + "read-%s-%d-%d-%d", appId, shuffleId, mapId, reduceId); + client.setClientId(requestID); logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + ByteBuffer response = client.sendRpcSync(openMessage.toByteBuffer(), 60000); logger.info("response is: " + response.toString() + " " + response.array() + " " + response.hasArray()); if (response.hasArray()) { logger.info("response hashcode: " + Arrays.hashCode(response.array())); - ByteArrayInputStream responseStream = - new ByteArrayInputStream(response.array()); - logger.info(String.format( - "Stream info %d %d", - responseStream.available(), - responseStream.read())); // use heap buffer; no array is created; only the reference is used - return new DataInputStream(responseStream); + return new ByteArrayInputStream(response.array()); } - ByteBufferInputStream responseStream = - new ByteBufferInputStream(response); - return new DataInputStream(responseStream); + return new ByteBufferInputStream(response); + } catch (Exception e) { if (client != null) { client.close(); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 30930f855882..1c78f186225f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -54,7 +54,8 @@ public ExternalShufflePartitionWriter( public long commitAndGetTotalLength() { RpcResponseCallback callback = new RpcResponseCallback() { @Override - public void onSuccess(ByteBuffer response) { logger.info("Successfully uploaded partition"); + public void onSuccess(ByteBuffer response) { + logger.info("Successfully uploaded partition"); } @Override diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7632c35f0318..0974c9139274 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -53,9 +53,11 @@ private[spark] class BlockStoreShuffleReader[K, C]( appId, handle.shuffleId, mapId) blockIds.map { case blockId@ShuffleBlockId(_, _, reduceId) => - (blockId, reader.fetchPartition(reduceId)) + (blockId, serializerManager.wrapStream(blockId, + reader.fetchPartition(reduceId))) case dataBlockId@ShuffleDataBlockId(_, _, reduceId) => - (dataBlockId, reader.fetchPartition(reduceId)) + (dataBlockId, serializerManager.wrapStream(dataBlockId, + reader.fetchPartition(reduceId))) case invalid => throw new IllegalArgumentException(s"Invalid block id $invalid") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index 5f943f15c073..e1eac0755883 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -31,7 +31,6 @@ import org.apache.spark.deploy.k8s.Config.KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEA import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.FileSegmentManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithID, TransportClient} -import org.apache.spark.network.server.OneForOneStreamManager import org.apache.spark.network.shuffle._ import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.network.util.{JavaUtils, TransportConf} From c91574dda4f6590741a19c182179a8b035c28170 Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Wed, 9 Jan 2019 16:24:24 -0800 Subject: [PATCH 14/30] IT WORKSSSSSSSS --- .../spark/shuffle/external/ExternalShuffleIndexWriter.java | 3 +-- .../spark/shuffle/external/ExternalShuffleMapOutputWriter.java | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java index 946060785596..fece52b05fce 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java @@ -53,8 +53,7 @@ public void onFailure(Throwable e) { TransportClient client = null; try { logger.info("Committing all partitions with a creation of an index file"); - logger.info("Partition Lengths: " + partitionLengths.length + ": " - + partitionLengths[0] + "," + partitionLengths[1]); + logger.info("Partition Lengths: " + partitionLengths.length); ByteBuffer streamHeader = new UploadShuffleIndexStream( appId, shuffleId, mapId).toByteBuffer(); // Size includes first 0L offset diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 786a56d46482..58c917bdffdb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -62,7 +62,7 @@ public void commitAllPartitions(long[] partitionLengths) { @Override public void abort(Exception exception) { clientFactory.close(); - logger.error("Encountered error while" + + logger.error("Encountered error while " + "attempting to add partitions to ESS", exception); } } From 7f1b215e16f3e598500d4c74dfe5192168e7348e Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 15 Jan 2019 10:35:13 -0800 Subject: [PATCH 15/30] scratch --- core/pom.xml | 5 ++ .../shuffle/api/ShuffleMapOutputWriter.java | 6 ++ .../external/ExternalShuffleDataIO.java | 7 ++- .../external/ExternalShuffleLocation.java | 31 ++++++++++ .../ExternalShuffleMapOutputWriter.java | 8 +++ .../external/ExternalShuffleReadSupport.java | 6 +- .../sort/BypassMergeSortShuffleWriter.java | 7 ++- .../shuffle/sort/UnsafeShuffleWriter.java | 6 +- .../apache/spark/scheduler/MapStatus.scala | 60 +++++++++++++++---- .../shuffle/sort/SortShuffleWriter.scala | 4 +- .../apache/spark/storage/BlockManager.scala | 5 +- .../spark/storage/ShuffleLocation.scala | 23 +++++++ .../apache/spark/MapOutputTrackerSuite.scala | 26 ++++---- .../serializer/KryoSerializerSuite.scala | 3 +- 14 files changed, 160 insertions(+), 37 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java create mode 100644 core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala diff --git a/core/pom.xml b/core/pom.xml index 49b1a54e3259..544ae61279c4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -352,6 +352,11 @@ py4j 0.10.8.1 + + org.scala-lang.modules + scala-java8-compat_${scala.binary.version} + 0.9.0 + org.apache.spark spark-tags_${scala.binary.version} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index f0f7d5ade602..60022cbbdada 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -17,11 +17,17 @@ package org.apache.spark.shuffle.api; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + public interface ShuffleMapOutputWriter { ShufflePartitionWriter newPartitionWriter(int partitionId); void commitAllPartitions(long[] partitionLengths); + Optional getShuffleLocation(); + void abort(Exception exception); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index da35ac76f343..0eae7e7958a5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -1,5 +1,6 @@ package org.apache.spark.shuffle.external; +import org.apache.spark.MapOutputTracker; import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; import org.apache.spark.network.netty.SparkTransportConf; @@ -18,20 +19,20 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private static final SparkEnv sparkEnv = SparkEnv.get(); private static final BlockManager blockManager = sparkEnv.blockManager(); - private final SparkConf sparkConf; private final TransportConf conf; private final SecurityManager securityManager; private final String hostname; private final int port; + private final MapOutputTracker mapOutputTracker; public ExternalShuffleDataIO( SparkConf sparkConf) { - this.sparkConf = sparkConf; this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1); this.securityManager = sparkEnv.securityManager(); this.hostname = blockManager.getRandomShuffleHost(); this.port = blockManager.getRandomShufflePort(); + this.mapOutputTracker = sparkEnv.mapOutputTracker(); } @Override @@ -43,7 +44,7 @@ public void initialize() { public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port); + securityManager, hostname, port, mapOutputTracker); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java new file mode 100644 index 000000000000..1cdf512ab657 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -0,0 +1,31 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.network.protocol.Encoders; +import org.apache.spark.storage.ShuffleLocation; + +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; + +public class ExternalShuffleLocation implements ShuffleLocation { + + private String shuffleHostname; + private int shufflePort; + + public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { + this.shuffleHostname = shuffleHostname; + this.shufflePort = shufflePort; + } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + out.writeObject(shuffleHostname); + out.writeInt(shufflePort); + } + + @Override + public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { + this.shuffleHostname = (String) in.readObject(); + this.shufflePort = in.readInt(); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 58c917bdffdb..fa4525614caf 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -3,9 +3,12 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.Optional; + public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final TransportClientFactory clientFactory; @@ -59,6 +62,11 @@ public void commitAllPartitions(long[] partitionLengths) { } } + @Override + public Optional getShuffleLocation() { + return Optional.of(new ExternalShuffleLocation(hostName, port)); + } + @Override public void abort(Exception exception) { clientFactory.close(); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index ddff937d47c2..d20c790e2e20 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -1,6 +1,7 @@ package org.apache.spark.shuffle.external; import com.google.common.collect.Lists; +import org.apache.spark.MapOutputTracker; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; @@ -24,18 +25,21 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private final SecretKeyHolder secretKeyHolder; private final String hostName; private final int port; + private final MapOutputTracker mapOutputTracker; public ExternalShuffleReadSupport( TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, String hostName, - int port) { + int port, + MapOutputTracker mapOutputTracker) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; this.hostName = hostName; this.port = port; + this.mapOutputTracker = mapOutputTracker; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 2cdf0c4600ae..1004bfe8d7e6 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -50,6 +50,7 @@ import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; +import scala.compat.java8.OptionConverters; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -95,6 +96,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; private long[] partitionLengths; + private Option shuffleLocation = Option.empty(); /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -133,7 +135,7 @@ public void write(Iterator> records) throws IOException { if (!records.hasNext()) { partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -179,7 +181,7 @@ public void write(Iterator> records) throws IOException { } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); } @VisibleForTesting @@ -268,6 +270,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio } } mapOutputWriter.commitAllPartitions(lengths); + shuffleLocation = OptionConverters.toScala(mapOutputWriter.getShuffleLocation()); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4e299034a893..3a51aa3ac4b1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -22,9 +22,11 @@ import java.nio.channels.FileChannel; import java.util.Iterator; +import org.apache.spark.storage.ShuffleLocation; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; +import scala.compat.java8.OptionConverters; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -89,6 +91,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; + private Option shuffleLocation = Option.empty(); private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ @@ -257,7 +260,7 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); } @VisibleForTesting @@ -564,6 +567,7 @@ private long[] mergeSpillsWithPluggableWriter( } } mapOutputWriter.commitAllPartitions(partitionLengths); + shuffleLocation = OptionConverters.toScala(mapOutputWriter.getShuffleLocation()); threwException = false; } catch (Exception e) { try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 64f0a060a247..5067b46ba8f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -25,7 +25,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.SparkEnv import org.apache.spark.internal.config -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils /** @@ -36,6 +36,8 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId + def shuffleLocation: Option[ShuffleLocation] + /** * Estimated size for the reduce block, in bytes. * @@ -56,11 +58,12 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocation: Option[ShuffleLocation]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes) + HighlyCompressedMapStatus(loc, uncompressedSizes, shuffleLocation) } else { - new CompressedMapStatus(loc, uncompressedSizes) + new CompressedMapStatus(loc, uncompressedSizes, shuffleLocation) } } @@ -103,17 +106,22 @@ private[spark] object MapStatus { */ private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, - private[this] var compressedSizes: Array[Byte]) + private[this] var compressedSizes: Array[Byte], + private[this] var shuffleLoc: Option[ShuffleLocation]) extends MapStatus with Externalizable { - protected def this() = this(null, null.asInstanceOf[Array[Byte]]) // For deserialization only + // For deserialization only + protected def this() = this(null, null.asInstanceOf[Array[Byte]], null) - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize)) + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLoc: Option[ShuffleLocation]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLoc) } override def location: BlockManagerId = loc + override def shuffleLocation: Option[ShuffleLocation] = shuffleLoc + override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) } @@ -122,6 +130,12 @@ private[spark] class CompressedMapStatus( loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) + if (shuffleLoc.isDefined) { + out.writeBoolean(true) + shuffleLoc.get.writeExternal(out) + } else { + out.writeBoolean(false) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -129,6 +143,12 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) + val shuffleLocationExists = in.readBoolean() + if (shuffleLocationExists) { + shuffleLoc = Option.apply(in.readObject().asInstanceOf[ShuffleLocation]) + } else { + shuffleLoc = Option.empty + } } } @@ -148,17 +168,20 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var numNonEmptyBlocks: Int, private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, - private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte]) + private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], + private[this] var shuffleLoc: Option[ShuffleLocation]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization require(loc == null || avgSize > 0 || hugeBlockSizes.size > 0 || numNonEmptyBlocks == 0, "Average size can only be zero for map stages that produced no output") - protected def this() = this(null, -1, null, -1, null) // For deserialization only + protected def this() = this(null, -1, null, -1, null, null) // For deserialization only override def location: BlockManagerId = loc + override def shuffleLocation: Option[ShuffleLocation] = shuffleLoc + override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) if (emptyBlocks.contains(reduceId)) { @@ -180,6 +203,12 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } + if (shuffleLoc.isDefined) { + out.writeBoolean(true) + shuffleLoc.get.writeExternal(out) + } else { + out.writeBoolean(false) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -195,11 +224,18 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl + val shuffleLocationExists = in.readBoolean() + if (shuffleLocationExists) { + shuffleLoc = Option.apply(in.readObject().asInstanceOf[ShuffleLocation]) + } else { + shuffleLoc = Option.empty + } } } private[spark] object HighlyCompressedMapStatus { - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], + shuffleLocation: Option[ShuffleLocation]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 @@ -240,6 +276,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes) + hugeBlockSizes, shuffleLocation) } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 1c804c99d0e3..33736d56706d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -49,6 +49,8 @@ private[spark] class SortShuffleWriter[K, V, C]( private val writeMetrics = context.taskMetrics().shuffleWriteMetrics + private val shuffleLocation = Option.empty + /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -76,7 +78,7 @@ private[spark] class SortShuffleWriter[K, V, C]( if (pluggableWriteSupport.isEmpty) { shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) } - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, shuffleLocation) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fb1ed02c857a..1575b076d3fa 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import java.io._ -import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} +import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections @@ -31,12 +31,11 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal - import com.codahale.metrics.{MetricRegistry, MetricSet} import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{Logging, config} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala new file mode 100644 index 000000000000..72846cb001c8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleLocation.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.storage + +import java.io.Externalizable + +trait ShuffleLocation extends Externalizable { + +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 21f481d47724..9506c86cdd5d 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -62,9 +62,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L))) + Array(1000L, 10000L), Option.empty)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L))) + Array(10000L, 1000L), Option.empty)) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +84,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array(compressedSize1000, compressedSize10000), Option.empty)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000), Option.empty)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +107,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array(compressedSize1000, compressedSize1000, compressedSize1000), Option.empty)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array(compressedSize10000, compressedSize1000, compressedSize1000), Option.empty)) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) + BlockManagerId("a", "hostA", 1000), Array(1000L), Option.empty)) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +182,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), Option.empty)) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +216,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), Option.empty)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L))) + Array(2L), Option.empty)) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L))) + Array(3L), Option.empty)) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. @@ -309,9 +309,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000))) + Array(size0, size1000, size0, size10000), Option.empty)) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0))) + Array(size10000, size0, size1000, size0), Option.empty)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 467e49026a02..6e89ab710206 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -349,7 +349,8 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize(HighlyCompressedMapStatus( + BlockManagerId("exec-1", "host", 1234), blockSizes, Option.empty)) } } From d0c8f29033c798300e8184601679023224a709be Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 15 Jan 2019 11:57:42 -0800 Subject: [PATCH 16/30] attempt 1 --- .../external/ExternalShuffleLocation.java | 8 ++++++++ .../external/ExternalShuffleReadSupport.java | 12 +++++++++++- .../org/apache/spark/MapOutputTracker.scala | 18 +++++++++++++++++- .../org/apache/spark/scheduler/MapStatus.scala | 12 ++++++++++++ .../shuffle/sort/UnsafeShuffleWriterSuite.java | 5 +++++ .../org/apache/spark/SplitFilesShuffleIO.scala | 5 +++++ .../scheduler/cluster/YarnClusterManager.scala | 2 +- 7 files changed, 59 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java index 1cdf512ab657..9c1b8db027a0 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -28,4 +28,12 @@ public void readExternal(ObjectInput in) throws IOException, ClassNotFoundExcept this.shuffleHostname = (String) in.readObject(); this.shufflePort = in.readInt(); } + + public String getShuffleHostname() { + return this.shuffleHostname; + } + + public int getShufflePort() { + return this.shufflePort; + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index d20c790e2e20..b98eb3b10916 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -11,10 +11,13 @@ import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.shuffle.api.ShuffleReadSupport; +import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.compat.java8.OptionConverters; import java.util.List; +import java.util.Optional; public class ExternalShuffleReadSupport implements ShuffleReadSupport { @@ -51,9 +54,16 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + Optional maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId)); + assert maybeShuffleLocation.isPresent(); + ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get(); try { return new ExternalShufflePartitionReader(clientFactory, - hostName, port, appId, shuffleId, mapId); + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort(), + appId, + shuffleId, + mapId); } catch (Exception e) { clientFactory.close(); logger.error("Encountered creating transport client for partition reader"); diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fb587f02256e..43696f1e5fa5 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -35,7 +35,7 @@ import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle._ -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleLocation} import org.apache.spark.util._ /** @@ -303,6 +303,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] + def getShuffleLocation(shuffleId: Int, mapId: Int) : Option[ShuffleLocation] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -676,6 +678,13 @@ private[spark] class MapOutputTrackerMaster( trackerEndpoint = null shuffleStatuses.clear() } + + override def getShuffleLocation(shuffleId: Int, mapId: Int): Option[ShuffleLocation] = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocation + case None => Option.empty + } + } } /** @@ -789,6 +798,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } } + + override def getShuffleLocation(shuffleId: Int, mapId: Int): Option[ShuffleLocation] = { + mapStatuses.get(shuffleId) match { + case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocation + case None => Option.empty + } + } } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 5067b46ba8f8..6b262c27fa72 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -67,6 +67,14 @@ private[spark] object MapStatus { } } + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { + if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { + HighlyCompressedMapStatus(loc, uncompressedSizes, Option.empty) + } else { + new CompressedMapStatus(loc, uncompressedSizes, Option.empty) + } + } + private[this] val LOG_BASE = 1.1 /** @@ -118,6 +126,10 @@ private[spark] class CompressedMapStatus( this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLoc) } + def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), Option.empty) + } + override def location: BlockManagerId = loc override def shuffleLocation: Option[ShuffleLocation] = shuffleLoc diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0b18aceef92d..68293272fde6 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -700,6 +700,11 @@ public void commitAllPartitions(long[] partitionlegnths) { } + @Override + public Optional getShuffleLocation() { + return Optional.empty(); + } + @Override public void abort(Exception failureReason) { diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index f6ac1fcc05a1..e1396a74acc4 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -19,9 +19,12 @@ package org.apache.spark import java.io._ import java.nio.file.Paths +import java.util.Optional + import javax.ws.rs.core.UriBuilder import org.apache.spark.shuffle.api._ +import org.apache.spark.storage.ShuffleLocation import org.apache.spark.util.Utils class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { @@ -59,6 +62,8 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { override def commitAllPartitions(partitionLengths: Array[Long]): Unit = {} override def abort(exception: Exception): Unit = {} + + override def getShuffleLocation: Optional[ShuffleLocation] = Optional.empty() } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index f3c9e3e2741f..b2a4fd42c60f 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -54,6 +54,6 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } - override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = DefaultShuffleServiceAddressProvider } From c2231a0ac7b5104464813f8a3b1576db876bbe1e Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Tue, 15 Jan 2019 17:26:41 -0800 Subject: [PATCH 17/30] resolving a few of the initial comments while still preserving correctness of e2e tests --- .../shuffle/ExternalShuffleBlockResolver.java | 15 +-- .../protocol/BlockTransferMessage.java | 5 +- .../protocol/RegisterShuffleIndex.java | 92 +++++++++++++++ ...dexStream.java => UploadShuffleIndex.java} | 10 +- .../UploadShufflePartitionStream.java | 19 ++- .../shuffle/api/ShuffleMapOutputWriter.java | 2 +- .../external/ExternalShuffleDataIO.java | 39 +++---- .../external/ExternalShuffleIndexWriter.java | 78 ------------- .../ExternalShuffleMapOutputWriter.java | 42 +++++-- .../ExternalShufflePartitionWriter.java | 4 +- .../external/ExternalShuffleReadSupport.java | 5 +- .../external/ExternalShuffleWriteSupport.java | 58 +++++----- .../sort/BypassMergeSortShuffleWriter.java | 2 +- .../shuffle/sort/UnsafeShuffleWriter.java | 2 +- .../scala/org/apache/spark/SparkContext.scala | 4 + .../scala/org/apache/spark/SparkEnv.scala | 8 +- .../org/apache/spark/executor/Executor.scala | 2 + .../shuffle/sort/SortShuffleManager.scala | 14 +-- .../util/collection/ExternalSorter.scala | 2 +- .../sort/UnsafeShuffleWriterSuite.java | 2 +- .../apache/spark/SplitFilesShuffleIO.scala | 2 +- .../spark/examples/GroupByShuffleTest.scala | 4 +- .../KubernetesExternalShuffleService.scala | 109 +++++++++++++++--- 23 files changed, 326 insertions(+), 194 deletions(-) create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java rename common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/{UploadShuffleIndexStream.java => UploadShuffleIndex.java} (90%) delete mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 6b1d879d0618..757b8f7b545b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -70,9 +70,6 @@ public class ExternalShuffleBlockResolver { private static final String APP_KEY_PREFIX = "AppExecShuffleInfo"; private static final StoreVersion CURRENT_VERSION = new StoreVersion(1, 0); - // TODO: Dont necessarily write to local - private final File shuffleDir; - private static final Pattern MULTIPLE_SEPARATORS = Pattern.compile(File.separator + "{2,}"); // Map containing all registered executors' metadata. @@ -96,8 +93,8 @@ public class ExternalShuffleBlockResolver { final DB db; private final List knownManagers = Arrays.asList( - "org.apache.spark.shuffle.sort.SortShuffleManager", - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); + "org.apache.spark.shuffle.sort.SortShuffleManager", + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { @@ -136,9 +133,6 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) { executors = Maps.newConcurrentMap(); } - // TODO: Remove local writes - this.shuffleDir = Files.createTempDirectory("spark-shuffle-dir").toFile(); - this.directoryCleaner = directoryCleaner; } @@ -146,7 +140,6 @@ public int getRegisteredExecutorsSize() { return executors.size(); } - /** Registers a new Executor with all the configuration we need to find its shuffle files. */ public void registerExecutor( String appId, @@ -313,8 +306,8 @@ private ManagedBuffer getSortBasedShuffleBlockData( * Hashes a filename into the corresponding local directory, in a manner consistent with * Spark's DiskBlockManager.getFile(). */ - - public static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { + @VisibleForTesting + static File getFile(String[] localDirs, int subDirsPerLocalDir, String filename) { int hash = JavaUtils.nonNegativeHash(filename); String localDir = localDirs[hash % localDirs.length]; int subDirId = (hash / localDirs.length) % subDirsPerLocalDir; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index b2b0f3f9796c..f5196638f914 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -41,7 +41,7 @@ public abstract class BlockTransferMessage implements Encodable { public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), UPLOAD_SHUFFLE_PARTITION_STREAM(7), - UPLOAD_SHUFFLE_INDEX_STREAM(8), OPEN_SHUFFLE_PARTITION(9); + REGISTER_SHUFFLE_INDEX(8), OPEN_SHUFFLE_PARTITION(9), UPLOAD_SHUFFLE_INDEX(10); private final byte id; @@ -68,8 +68,9 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 5: return ShuffleServiceHeartbeat.decode(buf); case 6: return UploadBlockStream.decode(buf); case 7: return UploadShufflePartitionStream.decode(buf); - case 8: return UploadShuffleIndexStream.decode(buf); + case 8: return RegisterShuffleIndex.decode(buf); case 9: return OpenShufflePartition.decode(buf); + case 10: return UploadShuffleIndex.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java new file mode 100644 index 000000000000..27f101171834 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + +/** + * Register shuffle index to the External Shuffle Service. + */ +public class RegisterShuffleIndex extends BlockTransferMessage { + public final String appId; + public final int shuffleId; + public final int mapId; + + public RegisterShuffleIndex( + String appId, + int shuffleId, + int mapId) { + this.appId = appId; + this.shuffleId = shuffleId; + this.mapId = mapId; + } + + @Override + public boolean equals(Object other) { + if (other != null && other instanceof UploadShufflePartitionStream) { + UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + return Objects.equal(appId, o.appId) + && shuffleId == o.shuffleId + && mapId == o.mapId; + } + return false; + } + + @Override + protected Type type() { + return Type.REGISTER_SHUFFLE_INDEX; + } + + @Override + public int hashCode() { + return Objects.hashCode(appId, shuffleId, mapId); + } + + @Override + public String toString() { + return Objects.toStringHelper(this) + .add("appId", appId) + .add("shuffleId", shuffleId) + .add("mapId", mapId) + .toString(); + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + 4 + 4; + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + buf.writeInt(shuffleId); + buf.writeInt(mapId); + } + + public static RegisterShuffleIndex decode(ByteBuf buf) { + String appId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapId = buf.readInt(); + return new RegisterShuffleIndex(appId, shuffleId, mapId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java similarity index 90% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java index ffa7ee36881c..374b399621aa 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndexStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java @@ -27,12 +27,12 @@ /** * Upload shuffle index request to the External Shuffle Service. */ -public class UploadShuffleIndexStream extends BlockTransferMessage { +public class UploadShuffleIndex extends BlockTransferMessage { public final String appId; public final int shuffleId; public final int mapId; - public UploadShuffleIndexStream( + public UploadShuffleIndex( String appId, int shuffleId, int mapId) { @@ -54,7 +54,7 @@ public boolean equals(Object other) { @Override protected Type type() { - return Type.UPLOAD_SHUFFLE_INDEX_STREAM; + return Type.UPLOAD_SHUFFLE_INDEX; } @Override @@ -83,10 +83,10 @@ public void encode(ByteBuf buf) { buf.writeInt(mapId); } - public static UploadShuffleIndexStream decode(ByteBuf buf) { + public static UploadShuffleIndex decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); int shuffleId = buf.readInt(); int mapId = buf.readInt(); - return new UploadShuffleIndexStream(appId, shuffleId, mapId); + return new UploadShuffleIndex(appId, shuffleId, mapId); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java index f0506cc08feb..ad8f5405192f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShufflePartitionStream.java @@ -32,16 +32,19 @@ public class UploadShufflePartitionStream extends BlockTransferMessage { public final int shuffleId; public final int mapId; public final int partitionId; + public final int partitionLength; public UploadShufflePartitionStream( String appId, int shuffleId, int mapId, - int partitionId) { + int partitionId, + int partitionLength) { this.appId = appId; this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; + this.partitionLength = partitionLength; } @Override @@ -51,7 +54,8 @@ public boolean equals(Object other) { return Objects.equal(appId, o.appId) && shuffleId == o.shuffleId && mapId == o.mapId - && partitionId == o.partitionId; + && partitionId == o.partitionId + && partitionLength == o.partitionLength; } return false; } @@ -63,7 +67,7 @@ protected Type type() { @Override public int hashCode() { - return Objects.hashCode(appId, shuffleId, mapId, partitionId); + return Objects.hashCode(appId, shuffleId, mapId, partitionId, partitionLength); } @Override @@ -72,12 +76,14 @@ public String toString() { .add("appId", appId) .add("shuffleId", shuffleId) .add("mapId", mapId) + .add("partitionId", partitionId) + .add("partitionLength", partitionLength) .toString(); } @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4; + return Encoders.Strings.encodedLength(appId) + 4 + 4 + 4 + 4; } @Override @@ -86,6 +92,7 @@ public void encode(ByteBuf buf) { buf.writeInt(shuffleId); buf.writeInt(mapId); buf.writeInt(partitionId); + buf.writeInt(partitionLength); } public static UploadShufflePartitionStream decode(ByteBuf buf) { @@ -93,6 +100,8 @@ public static UploadShufflePartitionStream decode(ByteBuf buf) { int shuffleId = buf.readInt(); int mapId = buf.readInt(); int partitionId = buf.readInt(); - return new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId); + int partitionLength = buf.readInt(); + return new UploadShufflePartitionStream( + appId, shuffleId, mapId, partitionId, partitionLength); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index f0f7d5ade602..06415dba72d3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -21,7 +21,7 @@ public interface ShuffleMapOutputWriter { ShufflePartitionWriter newPartitionWriter(int partitionId); - void commitAllPartitions(long[] partitionLengths); + void commitAllPartitions(); void abort(Exception exception); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index da35ac76f343..22a1d3336615 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -2,7 +2,9 @@ import org.apache.spark.SparkConf; import org.apache.spark.SparkEnv; +import org.apache.spark.network.TransportContext; import org.apache.spark.network.netty.SparkTransportConf; +import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShuffleDataIO; import org.apache.spark.shuffle.api.ShuffleReadSupport; @@ -12,44 +14,41 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { - private static final String SHUFFLE_SERVICE_PORT_CONFIG = "spark.shuffle.service.port"; - private static final String DEFAULT_SHUFFLE_PORT = "7337"; - - private static final SparkEnv sparkEnv = SparkEnv.get(); - private static final BlockManager blockManager = sparkEnv.blockManager(); - - private final SparkConf sparkConf; private final TransportConf conf; - private final SecurityManager securityManager; - private final String hostname; - private final int port; + private final TransportContext context; + private static BlockManager blockManager; + private static SecurityManager securityManager; + private static String hostname; + private static int port; public ExternalShuffleDataIO( SparkConf sparkConf) { - this.sparkConf = sparkConf; this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1); - - this.securityManager = sparkEnv.securityManager(); - this.hostname = blockManager.getRandomShuffleHost(); - this.port = blockManager.getRandomShufflePort(); + // Close idle connections + this.context = new TransportContext(conf, new NoOpRpcHandler(), true, true); } @Override public void initialize() { - // TODO: move registerDriver and registerExecutor here + SparkEnv env = SparkEnv.get(); + blockManager = env.blockManager(); + securityManager = env.securityManager(); + hostname = blockManager.getRandomShuffleHost(); + port = blockManager.getRandomShufflePort(); + // TODO: Register Driver and Executor } @Override public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( - conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port); + conf, context, securityManager.isAuthenticationEnabled(), + securityManager, hostname, port); } @Override public ShuffleWriteSupport writeSupport() { return new ExternalShuffleWriteSupport( - conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port); + conf, context, securityManager.isAuthenticationEnabled(), + securityManager, hostname, port); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java deleted file mode 100644 index fece52b05fce..000000000000 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleIndexWriter.java +++ /dev/null @@ -1,78 +0,0 @@ -package org.apache.spark.shuffle.external; - -import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; -import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.client.TransportClientFactory; -import org.apache.spark.network.shuffle.protocol.UploadShuffleIndexStream; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.nio.ByteBuffer; -import java.nio.LongBuffer; - -public class ExternalShuffleIndexWriter { - - private final TransportClientFactory clientFactory; - private final String hostName; - private final int port; - private final String appId; - private final int shuffleId; - private final int mapId; - - public ExternalShuffleIndexWriter( - TransportClientFactory clientFactory, - String hostName, - int port, - String appId, - int shuffleId, - int mapId){ - this.clientFactory = clientFactory; - this.hostName = hostName; - this.port = port; - this.appId = appId; - this.shuffleId = shuffleId; - this.mapId = mapId; - } - - private static final Logger logger = - LoggerFactory.getLogger(ExternalShuffleIndexWriter.class); - - public void write(long[] partitionLengths) { - RpcResponseCallback callback = new RpcResponseCallback() { - @Override - public void onSuccess(ByteBuffer response) { - logger.info("Successfully uploaded index"); - } - - @Override - public void onFailure(Throwable e) { - logger.error("Encountered an error uploading index", e); - } - }; - TransportClient client = null; - try { - logger.info("Committing all partitions with a creation of an index file"); - logger.info("Partition Lengths: " + partitionLengths.length); - ByteBuffer streamHeader = new UploadShuffleIndexStream( - appId, shuffleId, mapId).toByteBuffer(); - // Size includes first 0L offset - ByteBuffer byteBuffer = ByteBuffer.allocate(8 + (partitionLengths.length * 8)); - LongBuffer longBuffer = byteBuffer.asLongBuffer(); - Long offset = 0L; - longBuffer.put(offset); - for (Long length: partitionLengths) { - offset += length; - longBuffer.put(offset); - } - client = clientFactory.createUnmanagedClient(hostName, port); - client.setClientId(String.format("index-%s-%d-%d", appId, shuffleId, mapId)); - logger.info("clientid: " + client.getClientId() + " " + client.isActive()); - client.uploadStream(new NioManagedBuffer(streamHeader), - new NioManagedBuffer(byteBuffer), callback); - } catch (Exception e) { - client.close(); - logger.error("Encountered error while creating transport client", e); - } - } -} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 58c917bdffdb..34a11ce2b2a3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -1,11 +1,16 @@ package org.apache.spark.shuffle.external; +import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; +import org.apache.spark.network.shuffle.protocol.RegisterShuffleIndex; +import org.apache.spark.network.shuffle.protocol.UploadShuffleIndex; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.nio.ByteBuffer; + public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { private final TransportClientFactory clientFactory; @@ -28,6 +33,22 @@ public ExternalShuffleMapOutputWriter( this.appId = appId; this.shuffleId = shuffleId; this.mapId = mapId; + + TransportClient client = null; + try { + client = clientFactory.createUnmanagedClient(hostName, port); + ByteBuffer registerShuffleIndex = new RegisterShuffleIndex( + appId, shuffleId, mapId).toByteBuffer(); + String requestID = String.format( + "index-register-%s-%d-%d", appId, shuffleId, mapId); + client.setClientId(requestID); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + client.sendRpcSync(registerShuffleIndex, 60000); + } catch (Exception e) { + client.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); + } } private static final Logger logger = @@ -46,16 +67,21 @@ public ShufflePartitionWriter newPartitionWriter(int partitionId) { } @Override - public void commitAllPartitions(long[] partitionLengths) { + public void commitAllPartitions() { + TransportClient client = null; try { - ExternalShuffleIndexWriter externalShuffleIndexWriter = - new ExternalShuffleIndexWriter(clientFactory, - hostName, port, appId, shuffleId, mapId); - externalShuffleIndexWriter.write(partitionLengths); + client = clientFactory.createUnmanagedClient(hostName, port); + ByteBuffer uploadShuffleIndex = new UploadShuffleIndex( + appId, shuffleId, mapId).toByteBuffer(); + String requestID = String.format( + "index-upload-%s-%d-%d", appId, shuffleId, mapId); + client.setClientId(requestID); + logger.info("clientid: " + client.getClientId() + " " + client.isActive()); + client.sendRpcSync(uploadShuffleIndex, 60000); } catch (Exception e) { - clientFactory.close(); - logger.error("Encountered error writing index file", e); - throw new RuntimeException(e); // what is standard practice here? + client.close(); + logger.error("Encountered error while creating transport client", e); + throw new RuntimeException(e); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 1c78f186225f..d9b7d7ac515d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -65,10 +65,10 @@ public void onFailure(Throwable e) { }; TransportClient client = null; try { - ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, - partitionId).toByteBuffer(); byte[] buf = partitionBuffer.toByteArray(); int size = buf.length; + ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, + partitionId, size).toByteBuffer(); ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); client = clientFactory.createUnmanagedClient(hostName, port); client.setClientId(String.format("data-%s-%d-%d-%d", diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index ddff937d47c2..2687c2a4e237 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -6,7 +6,6 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.shuffle.api.ShuffleReadSupport; @@ -20,6 +19,7 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleReadSupport.class); private final TransportConf conf; + private final TransportContext context; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; private final String hostName; @@ -27,11 +27,13 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { public ExternalShuffleReadSupport( TransportConf conf, + TransportContext context, boolean authEnabled, SecretKeyHolder secretKeyHolder, String hostName, int port) { this.conf = conf; + this.context = context; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; this.hostName = hostName; @@ -41,7 +43,6 @@ public ExternalShuffleReadSupport( @Override public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) { // TODO combine this into a function with ExternalShuffleWriteSupport - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false); List bootstraps = Lists.newArrayList(); if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java index 4754c58f136b..413c2fd63f20 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleWriteSupport.java @@ -6,7 +6,6 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.crypto.AuthClientBootstrap; import org.apache.spark.network.sasl.SecretKeyHolder; -import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShuffleWriteSupport; @@ -17,34 +16,39 @@ public class ExternalShuffleWriteSupport implements ShuffleWriteSupport { - private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class); + private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleWriteSupport.class); - private final TransportConf conf; - private final boolean authEnabled; - private final SecretKeyHolder secretKeyHolder; - private final String hostname; - private final int port; + private final TransportConf conf; + private final TransportContext context; + private final boolean authEnabled; + private final SecretKeyHolder secretKeyHolder; + private final String hostname; + private final int port; - public ExternalShuffleWriteSupport( - TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostname, int port) { - this.conf = conf; - this.authEnabled = authEnabled; - this.secretKeyHolder = secretKeyHolder; - this.hostname = hostname; - this.port = port; - } + public ExternalShuffleWriteSupport( + TransportConf conf, + TransportContext context, + boolean authEnabled, + SecretKeyHolder secretKeyHolder, + String hostname, + int port) { + this.conf = conf; + this.context = context; + this.authEnabled = authEnabled; + this.secretKeyHolder = secretKeyHolder; + this.hostname = hostname; + this.port = port; +} - @Override - public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { - TransportContext context = new TransportContext(conf, new NoOpRpcHandler(), false); - List bootstraps = Lists.newArrayList(); - if (authEnabled) { - bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); - } - TransportClientFactory clientFactory = context.createClientFactory(bootstraps); - logger.info("Clientfactory: " + clientFactory.toString()); - return new ExternalShuffleMapOutputWriter( - clientFactory, hostname, port, appId, shuffleId, mapId); + @Override + public ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) { + List bootstraps = Lists.newArrayList(); + if (authEnabled) { + bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); + logger.info("Clientfactory: " + clientFactory.toString()); + return new ExternalShuffleMapOutputWriter( + clientFactory, hostname, port, appId, shuffleId, mapId); + } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 2cdf0c4600ae..823c36d051dd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -267,7 +267,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio } } } - mapOutputWriter.commitAllPartitions(lengths); + mapOutputWriter.commitAllPartitions(); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 4e299034a893..32be62009511 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -563,7 +563,7 @@ private long[] mergeSpillsWithPluggableWriter( throw e; } } - mapOutputWriter.commitAllPartitions(partitionLengths); + mapOutputWriter.commitAllPartitions(); threwException = false; } catch (Exception e) { try { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 845a3d5f6d6f..247016584d1f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -576,6 +576,10 @@ class SparkContext(config: SparkConf) extends Logging { _env.metricsSystem.registerSource(e.executorAllocationManagerSource) } appStatusSource.foreach(_env.metricsSystem.registerSource(_)) + + // Initialize the ShuffleDataIo + _env.shuffleDataIO.foreach(_.initialize()) + // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 45aabc05f49b..c2b56864bf36 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.io.File import java.net.Socket -import java.util.{Locale, ServiceLoader} +import java.util.Locale import com.google.common.collect.MapMaker import scala.collection.mutable @@ -39,6 +39,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.security.CryptoStreamUtils import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} import org.apache.spark.shuffle.{DefaultShuffleServiceAddressProvider, ShuffleManager, ShuffleServiceAddressProviderFactory} +import org.apache.spark.shuffle.api.ShuffleDataIO import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} @@ -65,6 +66,7 @@ class SparkEnv ( val blockManager: BlockManager, val securityManager: SecurityManager, val metricsSystem: MetricsSystem, + val shuffleDataIO: Option[ShuffleDataIO], val memoryManager: MemoryManager, val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { @@ -383,6 +385,9 @@ object SparkEnv extends Logging { ms } + val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) + .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) + val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf, isDriver) } @@ -402,6 +407,7 @@ object SparkEnv extends Logging { blockManager, securityManager, metricsSystem, + shuffleIoPlugin, memoryManager, outputCommitCoordinator, conf) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index a30a501e5d4a..ae5b1a3c6946 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -118,6 +118,8 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) env.metricsSystem.registerSource(executorSource) env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource) + // Initialize the ShuffleDataIo + env.shuffleDataIO.foreach(_.initialize()) } // Whether to load classes in user jars before those in Spark jars diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index ba56da9089a7..eb7ae313918e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -119,9 +119,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition: Int, context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { - val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) - shuffleIoPlugin.foreach(_.initialize()) new BlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], conf.getAppId, @@ -129,7 +126,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager endPartition, context, metrics, - shuffleIoPlugin.map(_.readSupport())) + SparkEnv.get.shuffleDataIO.map(_.readSupport())) } /** Get a writer for a given partition. Called on executors by map tasks. */ @@ -141,9 +138,6 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager numMapsForShuffle.putIfAbsent( handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps) val env = SparkEnv.get - val shuffleIoPlugin = conf.get(SHUFFLE_IO_PLUGIN_CLASS) - .map(clazz => Utils.loadExtensions(classOf[ShuffleDataIO], Seq(clazz), conf).head) - shuffleIoPlugin.foreach(_.initialize()) handle match { case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( @@ -155,7 +149,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager context, env.conf, metrics, - shuffleIoPlugin.map(_.writeSupport()).orNull) + env.shuffleDataIO.map(_.writeSupport()).orNull) case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, @@ -164,10 +158,10 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager mapId, env.conf, metrics, - shuffleIoPlugin.map(_.writeSupport()).orNull) + env.shuffleDataIO.map(_.writeSupport()).orNull) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter( - shuffleBlockResolver, other, mapId, context, shuffleIoPlugin.map(_.writeSupport())) + shuffleBlockResolver, other, mapId, context, env.shuffleDataIO.map(_.writeSupport())) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 01cc838474d8..569c8bd092f3 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -778,7 +778,7 @@ private[spark] class ExternalSorter[K, V, C]( } } } - mapOutputWriter.commitAllPartitions(lengths) + mapOutputWriter.commitAllPartitions() } catch { case e: Exception => util.Utils.tryLogNonFatalError { diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 0b18aceef92d..539336cd4fd8 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -696,7 +696,7 @@ public void abort(Exception failureReason) { } @Override - public void commitAllPartitions(long[] partitionlegnths) { + public void commitAllPartitions() { } diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index f6ac1fcc05a1..3a68fded945b 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -56,7 +56,7 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { } } - override def commitAllPartitions(partitionLengths: Array[Long]): Unit = {} + override def commitAllPartitions(): Unit = {} override def abort(exception: Exception): Unit = {} } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala index 0ce0d3bec6cd..883ac10718df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -18,12 +18,10 @@ // scalastyle:off println package org.apache.spark.examples -import java.util.Random - import org.apache.spark.sql.SparkSession /** - * Usage: GroupByShuffleTest [numMappers] [numKVPairs] [KeySize] [numReducers] + * Usage: GroupByShuffleTest */ object GroupByShuffleTest { def main(args: Array[String]) { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index e1eac0755883..b9d69f1bc69f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -17,13 +17,17 @@ package org.apache.spark.deploy.k8s -import java.io.File +import java.io.{DataOutputStream, File, FileOutputStream} import java.nio.ByteBuffer import java.nio.file.Paths +import java.util import java.util.concurrent.{ConcurrentHashMap, ExecutionException, TimeUnit} +import java.util.function.BiFunction +import com.codahale.metrics._ import com.google.common.cache.{CacheBuilder, CacheLoader, Weigher} import scala.collection.JavaConverters._ +import scala.collection.immutable.TreeMap import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.ExternalShuffleService @@ -46,7 +50,6 @@ private[spark] class KubernetesExternalShuffleBlockHandler( indexCacheSize: String) extends ExternalShuffleBlockHandler(transportConf, null) with Logging { - ThreadUtils.newDaemonSingleThreadScheduledExecutor("shuffle-cleaner-watcher") .scheduleAtFixedRate(new CleanerThread(), 0, cleanerIntervals, TimeUnit.SECONDS) @@ -63,14 +66,23 @@ private[spark] class KubernetesExternalShuffleBlockHandler( }) .build(indexCacheLoader) + // TODO: Investigate cleanup if appId is terminated + private val globalPartitionLengths = new ConcurrentHashMap[(String, Int, Int), TreeMap[Int, Long]] + private final val shuffleDir = Utils.createDirectory("/tmp", "spark-shuffle-dir") + private final val metricSet: RemoteShuffleMetrics = new RemoteShuffleMetrics() + + private def scanLeft[a, b](xs: Iterable[a])(s: b)(f: (b, a) => b) = + xs.foldLeft(List(s))( (acc, x) => f(acc.head, x) :: acc).reverse + protected override def handleMessage( message: BlockTransferMessage, client: TransportClient, callback: RpcResponseCallback): Unit = { message match { case RegisterDriverParam(appId, appState) => + val responseDelayContext = metricSet.registerDriverRequestLatencyMillis.time() val address = client.getSocketAddress val timeout = appState.heartbeatTimeout logInfo(s"Received registration request from app $appId (remote address $address, " + @@ -84,6 +96,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( throw new RuntimeException(s"Failed to create dir ${driverDir.getAbsolutePath}") } connectedApps.put(appId, appState) + responseDelayContext.stop() callback.onSuccess(ByteBuffer.allocate(0)) case Heartbeat(appId) => @@ -97,9 +110,34 @@ private[spark] class KubernetesExternalShuffleBlockHandler( logWarning(s"Received ShuffleServiceHeartbeat from an unknown app (remote " + s"address $address, appId '$appId').") } + + case RegisterIndexParam(appId, shuffleId, mapId) => + logInfo(s"Received register index param from app $appId") + globalPartitionLengths.putIfAbsent( + (appId, shuffleId, mapId), TreeMap.empty[Int, Long]) + callback.onSuccess(ByteBuffer.allocate(0)) + + case UploadIndexParam(appId, shuffleId, mapId) => + val responseDelayContext = metricSet.writeIndexRequestLatencyMillis.time() + try { + logInfo(s"Received upload index param from app $appId") + val partitionMap = globalPartitionLengths.get((appId, shuffleId, mapId)) + val out = new DataOutputStream( + new FileOutputStream(getFile(appId, shuffleId, mapId, "index"))) + scanLeft(partitionMap.values)(0L)(_ + _).foreach(l => out.writeLong(l)) + out.close() + callback.onSuccess(ByteBuffer.allocate(0)) + } finally { + responseDelayContext.stop() + } + case OpenParam(appId, shuffleId, mapId, partitionId) => logInfo(s"Received open param from app $appId") + val responseDelayContext = metricSet.openBlockRequestLatencyMillis.time() val indexFile = getFile(appId, shuffleId, mapId, "index") + logInfo(s"Map: " + + s"${globalPartitionLengths.get((appId, shuffleId, mapId)).toString()}" + + s"for partitionId: $partitionId") try { val shuffleIndexInformation = shuffleIndexCache.get(indexFile) val shuffleIndexRecord = shuffleIndexInformation.getIndex(partitionId) @@ -111,6 +149,8 @@ private[spark] class KubernetesExternalShuffleBlockHandler( callback.onSuccess(managedBuffer.nioByteBuffer()) } catch { case e: ExecutionException => logError(s"Unable to write index file $indexFile", e) + } finally { + responseDelayContext.stop() } case _ => super.handleMessage(message, client, callback) } @@ -122,20 +162,30 @@ private[spark] class KubernetesExternalShuffleBlockHandler( callback: RpcResponseCallback): StreamCallbackWithID = { header match { case UploadParam( - appId, shuffleId, mapId, partitionId) => - // TODO: Investigate whether we should use the partitionId for Index File creation - logInfo(s"Received upload param from app $appId") - getFileWriterStreamCallback( - appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) - case UploadIndexParam(appId, shuffleId, mapId) => - logInfo(s"Received upload index param from app $appId") - getFileWriterStreamCallback( - appId, shuffleId, mapId, "index", FileWriterStreamCallback.FileType.INDEX) + appId, shuffleId, mapId, partitionId, partitionLength) => + val responseDelayContext = metricSet.writeBlockRequestLatencyMillis.time() + try { + logInfo(s"Received upload param from app $appId") + val lengthMap = TreeMap(partitionId -> partitionLength.toLong) + globalPartitionLengths.merge((appId, shuffleId, mapId), lengthMap, + new BiFunction[TreeMap[Int, Long], TreeMap[Int, Long], TreeMap[Int, Long]]() { + override def apply(t: TreeMap[Int, Long], u: TreeMap[Int, Long]): + TreeMap[Int, Long] = { + t ++ u + } + }) + getFileWriterStreamCallback( + appId, shuffleId, mapId, "data", FileWriterStreamCallback.FileType.DATA) + } finally { + responseDelayContext.stop() + } case _ => super.handleStream(header, client, callback) } } + protected override def getAllMetrics: MetricSet = metricSet + private def getFileWriterStreamCallback( appId: String, shuffleId: Int, @@ -169,12 +219,17 @@ private[spark] class KubernetesExternalShuffleBlockHandler( } private object UploadParam { - def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int)] = - Some((u.appId, u.shuffleId, u.mapId, u.partitionId)) + def unapply(u: UploadShufflePartitionStream): Option[(String, Int, Int, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId, u.partitionId, u.partitionLength)) } private object UploadIndexParam { - def unapply(u: UploadShuffleIndexStream): Option[(String, Int, Int)] = + def unapply(u: UploadShuffleIndex): Option[(String, Int, Int)] = + Some((u.appId, u.shuffleId, u.mapId)) + } + + private object RegisterIndexParam { + def unapply(u: RegisterShuffleIndex): Option[(String, Int, Int)] = Some((u.appId, u.shuffleId, u.mapId)) } @@ -204,6 +259,32 @@ private[spark] class KubernetesExternalShuffleBlockHandler( } } } + private class RemoteShuffleMetrics extends MetricSet { + private val allMetrics = new util.HashMap[String, Metric]() + // Time latency for write request in ms + private val _writeBlockRequestLatencyMillis = new Timer() + def writeBlockRequestLatencyMillis: Timer = _writeBlockRequestLatencyMillis + // Time latency for write index file in ms + private val _writeIndexRequestLatencyMillis = new Timer() + def writeIndexRequestLatencyMillis: Timer = _writeIndexRequestLatencyMillis + // Time latency for read request in ms + private val _openBlockRequestLatencyMillis = new Timer() + def openBlockRequestLatencyMillis: Timer = _openBlockRequestLatencyMillis + // Time latency for executor registration latency in ms + private val _registerDriverRequestLatencyMillis = new Timer() + def registerDriverRequestLatencyMillis: Timer = _registerDriverRequestLatencyMillis + // Block transfer rate in byte per second + private val _blockTransferRateBytes = new Meter() + def blockTransferRateBytes: Meter = _blockTransferRateBytes + + allMetrics.put("writeBlockRequestLatencyMillis", _writeBlockRequestLatencyMillis) + allMetrics.put("writeIndexRequestLatencyMillis", _writeIndexRequestLatencyMillis) + allMetrics.put("openBlockRequestLatencyMillis", _openBlockRequestLatencyMillis) + allMetrics.put("registerDriverRequestLatencyMillis", _registerDriverRequestLatencyMillis) + allMetrics.put("blockTransferRateBytes", _blockTransferRateBytes) + override def getMetrics: util.Map[String, Metric] = allMetrics + } + } /** From 45343fad7e903c8be874de71c8dc8c1aeb87c09a Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 15 Jan 2019 18:24:13 -0800 Subject: [PATCH 18/30] fix serialization --- .../external/ExternalShuffleLocation.java | 34 ++++++++++++++++--- .../apache/spark/scheduler/MapStatus.scala | 4 +-- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java index 9c1b8db027a0..3df7aded2fc4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -1,17 +1,18 @@ package org.apache.spark.shuffle.external; +import org.apache.hadoop.mapreduce.task.reduce.Shuffle; import org.apache.spark.network.protocol.Encoders; import org.apache.spark.storage.ShuffleLocation; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; +import java.io.*; public class ExternalShuffleLocation implements ShuffleLocation { private String shuffleHostname; private int shufflePort; + public ExternalShuffleLocation() { /* for serialization */ } + public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { this.shuffleHostname = shuffleHostname; this.shufflePort = shufflePort; @@ -19,13 +20,19 @@ public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { @Override public void writeExternal(ObjectOutput out) throws IOException { - out.writeObject(shuffleHostname); +// out.writeInt(shuffleHostname.length()); +// out.writeChars(shuffleHostname); + out.writeUTF(shuffleHostname); out.writeInt(shufflePort); } @Override public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { - this.shuffleHostname = (String) in.readObject(); +// int size = in.readInt(); +// byte[] buf = new byte[size]; +// in.read(buf, 0, size); +// this.shuffleHostname = new String(buf); + this.shuffleHostname = in.readUTF(); this.shufflePort = in.readInt(); } @@ -36,4 +43,21 @@ public String getShuffleHostname() { public int getShufflePort() { return this.shufflePort; } + + + public static void main(String[] args) throws IOException, ClassNotFoundException { + ExternalShuffleLocation externalShuffleLocation = new ExternalShuffleLocation("hostname", 1234); + ShuffleLocation shuffleLocation = (ShuffleLocation) externalShuffleLocation; + + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(shuffleLocation); + oos.flush(); + + + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + ObjectInputStream ois = new ObjectInputStream(bais); + ShuffleLocation newShuffLocation = (ShuffleLocation) ois.readObject(); + System.out.println(newShuffLocation); + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 6b262c27fa72..f259b5a44f74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -144,7 +144,7 @@ private[spark] class CompressedMapStatus( out.write(compressedSizes) if (shuffleLoc.isDefined) { out.writeBoolean(true) - shuffleLoc.get.writeExternal(out) + out.writeObject(shuffleLocation.get) } else { out.writeBoolean(false) } @@ -217,7 +217,7 @@ private[spark] class HighlyCompressedMapStatus private ( } if (shuffleLoc.isDefined) { out.writeBoolean(true) - shuffleLoc.get.writeExternal(out) + out.writeObject(shuffleLoc.get) } else { out.writeBoolean(false) } From a301d24c5621c3ffc35edc9442a5c9f145971985 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 16 Jan 2019 13:43:31 -0800 Subject: [PATCH 19/30] basic cleanup --- .../spark/shuffle/external/ExternalShuffleDataIO.java | 2 +- .../shuffle/external/ExternalShuffleReadSupport.java | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 0eae7e7958a5..440fa890bcf3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -44,7 +44,7 @@ public void initialize() { public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( conf, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port, mapOutputTracker); + securityManager, mapOutputTracker); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index b98eb3b10916..a99d23f179b4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -26,22 +26,16 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private final TransportConf conf; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; - private final String hostName; - private final int port; private final MapOutputTracker mapOutputTracker; public ExternalShuffleReadSupport( TransportConf conf, boolean authEnabled, SecretKeyHolder secretKeyHolder, - String hostName, - int port, MapOutputTracker mapOutputTracker) { this.conf = conf; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; - this.hostName = hostName; - this.port = port; this.mapOutputTracker = mapOutputTracker; } @@ -57,6 +51,9 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in Optional maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId)); assert maybeShuffleLocation.isPresent(); ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get(); + logger.info(String.format("Found external shuffle location on node: %s:%d", + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort())); try { return new ExternalShufflePartitionReader(clientFactory, externalShuffleLocation.getShuffleHostname(), From 3ba25ab91a028ee731028c57d5e2034b10c718bd Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 17 Jan 2019 14:31:30 -0800 Subject: [PATCH 20/30] compiles --- .../shuffle/protocol/UploadShuffleIndex.java | 4 +- .../spark/shuffle/api/CommittedPartition.java | 23 ++++++ .../shuffle/api/ShuffleMapOutputWriter.java | 6 -- .../shuffle/api/ShufflePartitionWriter.java | 7 +- .../external/ExternalCommittedPartition.java | 32 ++++++++ .../ExternalShuffleMapOutputWriter.java | 7 -- .../ExternalShufflePartitionReader.java | 3 +- .../ExternalShufflePartitionWriter.java | 7 +- .../external/ExternalShuffleReadSupport.java | 4 +- .../sort/BypassMergeSortShuffleWriter.java | 44 +++++------ .../shuffle/sort/UnsafeShuffleWriter.java | 21 +++-- .../org/apache/spark/MapOutputTracker.scala | 12 +-- .../apache/spark/scheduler/MapStatus.scala | 78 +++++++++---------- .../shuffle/sort/SortShuffleWriter.scala | 5 +- .../ShufflePartitionObjectWriter.scala | 4 +- .../util/collection/ExternalSorter.scala | 5 +- .../sort/UnsafeShuffleWriterSuite.java | 20 +++-- .../apache/spark/MapOutputTrackerSuite.scala | 31 ++++---- .../apache/spark/SplitFilesShuffleIO.scala | 14 ++-- .../serializer/KryoSerializerSuite.scala | 5 +- .../spark/examples/GroupByShuffleTest.scala | 1 + 21 files changed, 191 insertions(+), 142 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java create mode 100644 core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java index 374b399621aa..b11a02f6b921 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/UploadShuffleIndex.java @@ -43,8 +43,8 @@ public UploadShuffleIndex( @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadShufflePartitionStream) { - UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + if (other != null && other instanceof UploadShuffleIndex) { + UploadShuffleIndex o = (UploadShuffleIndex) other; return Objects.equal(appId, o.appId) && shuffleId == o.shuffleId && mapId == o.mapId; diff --git a/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java new file mode 100644 index 000000000000..7846fad70b15 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/CommittedPartition.java @@ -0,0 +1,23 @@ +package org.apache.spark.shuffle.api; + +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public interface CommittedPartition { + + /** + * Indicates the number of bytes written in a committed partition. + * Note that returning the length is mainly for backwards compatibility + * and should be removed in a more polished variant. After this method + * is called, the writer will be discarded; it's expected that the + * implementation will close any underlying resources. + */ + long length(); + + /** + * Indicates the shuffle location to which this partition was written. + * Some implementations may not need to specify a shuffle location. + */ + Optional shuffleLocation(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 3d3a3acabc75..06415dba72d3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -17,17 +17,11 @@ package org.apache.spark.shuffle.api; -import org.apache.spark.storage.ShuffleLocation; - -import java.util.Optional; - public interface ShuffleMapOutputWriter { ShufflePartitionWriter newPartitionWriter(int partitionId); void commitAllPartitions(); - Optional getShuffleLocation(); - void abort(Exception exception); } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index ae9ada03e760..bdc0fd45474f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -31,12 +31,9 @@ public interface ShufflePartitionWriter { /** * Indicate that the partition was written successfully and there are no more incoming bytes. - * Returns the length of the partition that is written. Note that returning the length is - * mainly for backwards compatibility and should be removed in a more polished variant. - * After this method is called, the writer will be discarded; it's expected that the - * implementation will close any underlying resources. + * Returns a {@link CommittedPartition} indicating information about that written partition. */ - long commitAndGetTotalLength(); + CommittedPartition commitPartition(); /** * Indicate that the write has failed for some reason and the implementation can handle the diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java new file mode 100644 index 000000000000..7e37659dbb3f --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalCommittedPartition.java @@ -0,0 +1,32 @@ +package org.apache.spark.shuffle.external; + +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class ExternalCommittedPartition implements CommittedPartition { + + private final long length; + private final Optional shuffleLocation; + + public ExternalCommittedPartition(long length) { + this.length = length; + this.shuffleLocation = Optional.empty(); + } + + public ExternalCommittedPartition(long length, ShuffleLocation shuffleLocation) { + this.length = length; + this.shuffleLocation = Optional.of(shuffleLocation); + } + + @Override + public long length() { + return length; + } + + @Override + public Optional shuffleLocation() { + return shuffleLocation; + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index 142dc4f4f8d9..fd23772d89e7 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -6,12 +6,10 @@ import org.apache.spark.network.shuffle.protocol.UploadShuffleIndex; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.nio.ByteBuffer; -import java.util.Optional; public class ExternalShuffleMapOutputWriter implements ShuffleMapOutputWriter { @@ -88,11 +86,6 @@ public void commitAllPartitions() { } } - @Override - public Optional getShuffleLocation() { - return Optional.of(new ExternalShuffleLocation(hostName, port)); - } - @Override public void abort(Exception exception) { clientFactory.close(); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index ee363f8bb41b..8aefac239e97 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -8,7 +8,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.*; +import java.io.ByteArrayInputStream; +import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Arrays; diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index d9b7d7ac515d..ef0c4842b17a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -6,13 +6,16 @@ import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Optional; public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { @@ -51,7 +54,7 @@ public ExternalShufflePartitionWriter( public OutputStream openPartitionStream() { return partitionBuffer; } @Override - public long commitAndGetTotalLength() { + public CommittedPartition commitPartition() { RpcResponseCallback callback = new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { @@ -88,7 +91,7 @@ public void onFailure(Throwable e) { } finally { logger.info("Successfully sent partition to ESS"); } - return totalLength; + return new ExternalCommittedPartition(totalLength, new ExternalShuffleLocation(hostName, port)); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index f3d402488d87..9e7ff55f4774 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -48,13 +48,13 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } - TransportClientFactory clientFactory = context.createClientFactory(bootstraps); - Optional maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId)); + Optional maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId, 0)); assert maybeShuffleLocation.isPresent(); ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get(); logger.info(String.format("Found external shuffle location on node: %s:%d", externalShuffleLocation.getShuffleHostname(), externalShuffleLocation.getShufflePort())); + TransportClientFactory clientFactory = context.createClientFactory(bootstraps); try { return new ExternalShufflePartitionReader(clientFactory, externalShuffleLocation.getShuffleHostname(), diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index ec9d413313f7..d90a05ab9a20 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -17,24 +17,8 @@ package org.apache.spark.shuffle.sort; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import javax.annotation.Nullable; - -import scala.None$; -import scala.Option; -import scala.Product2; -import scala.Tuple2; -import scala.collection.Iterator; - import com.google.common.annotations.VisibleForTesting; import com.google.common.io.Closeables; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; @@ -42,15 +26,25 @@ import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; -import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; import org.apache.spark.shuffle.ShuffleWriter; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; import org.apache.spark.shuffle.api.ShufflePartitionWriter; import org.apache.spark.shuffle.api.ShuffleWriteSupport; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; -import scala.compat.java8.OptionConverters; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import scala.None$; +import scala.Option; +import scala.Product2; +import scala.Tuple2; +import scala.collection.Iterator; + +import javax.annotation.Nullable; +import java.io.*; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -96,7 +90,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; private long[] partitionLengths; - private Option shuffleLocation = Option.empty(); + private ShuffleLocation[] shuffleLocations; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -134,8 +128,9 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { partitionLengths = new long[numPartitions]; + shuffleLocations = new ShuffleLocation[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -181,7 +176,7 @@ public void write(Iterator> records) throws IOException { } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); } @VisibleForTesting @@ -253,7 +248,11 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio try (OutputStream out = writer.openPartitionStream()) { Utils.copyStream(in, out, false, false); } - lengths[i] = writer.commitAndGetTotalLength(); + CommittedPartition committedPartition = writer.commitPartition(); + lengths[i] = committedPartition.length(); + if (committedPartition.shuffleLocation().isPresent()) { + shuffleLocations[i] = committedPartition.shuffleLocation().get(); + } copyThrewException = false; } catch (Exception e) { try { @@ -270,7 +269,6 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio } } mapOutputWriter.commitAllPartitions(); - shuffleLocation = OptionConverters.toScala(mapOutputWriter.getShuffleLocation()); } catch (Exception e) { try { mapOutputWriter.abort(e); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 80fd1ce312ad..5cb9d391f060 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -22,11 +22,11 @@ import java.nio.channels.FileChannel; import java.util.Iterator; +import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.storage.ShuffleLocation; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; -import scala.compat.java8.OptionConverters; import scala.reflect.ClassTag; import scala.reflect.ClassTag$; @@ -91,7 +91,7 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; - private Option shuffleLocation = Option.empty(); + private ShuffleLocation[] shuffleLocations; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ @@ -158,6 +158,7 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.outputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.shuffleLocations = new ShuffleLocation[numPartitions]; open(); } @@ -260,7 +261,7 @@ void closeAndWriteOutput() throws IOException { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocation); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); } @VisibleForTesting @@ -307,6 +308,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti return new long[partitioner.numPartitions()]; } else if (spills.length == 1) { if (pluggableWriteSupport != null) { + // TODO: should this be returning a partition length? writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); } else { // Here, we don't need to perform any metrics updates because the bytes written to this @@ -555,7 +557,11 @@ private long[] mergeSpillsWithPluggableWriter( } } } - partitionLengths[partition] = writer.commitAndGetTotalLength(); + CommittedPartition committedPartition = writer.commitPartition(); + if (committedPartition.shuffleLocation().isPresent()) { + shuffleLocations[partition] = committedPartition.shuffleLocation().get(); + } + partitionLengths[partition] = committedPartition.length(); writeMetrics.incBytesWritten(partitionLengths[partition]); } catch (Exception e) { try { @@ -567,7 +573,6 @@ private long[] mergeSpillsWithPluggableWriter( } } mapOutputWriter.commitAllPartitions(); - shuffleLocation = OptionConverters.toScala(mapOutputWriter.getShuffleLocation()); threwException = false; } catch (Exception e) { try { @@ -621,7 +626,11 @@ private void writeSingleSpillFileUsingPluggableWriter( } finally { partitionInputStream.close(); } - writeMetrics.incBytesWritten(writer.commitAndGetTotalLength()); + CommittedPartition committedPartition = writer.commitPartition(); + if (committedPartition.shuffleLocation().isPresent()) { + shuffleLocations[partition] = committedPartition.shuffleLocation().get(); + } + writeMetrics.incBytesWritten(committedPartition.length()); } threwException = false; } catch (Exception e) { diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 43696f1e5fa5..b340ffbbc43d 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -303,7 +303,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] - def getShuffleLocation(shuffleId: Int, mapId: Int) : Option[ShuffleLocation] + def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation] /** * Deletes map output status information for the specified shuffle stage. @@ -679,9 +679,10 @@ private[spark] class MapOutputTrackerMaster( shuffleStatuses.clear() } - override def getShuffleLocation(shuffleId: Int, mapId: Int): Option[ShuffleLocation] = { + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { shuffleStatuses.get(shuffleId) match { - case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocation + case Some(shuffleStatus) => shuffleStatus.mapStatuses(mapId).shuffleLocationForBlock(reduceId) case None => Option.empty } } @@ -799,9 +800,10 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } - override def getShuffleLocation(shuffleId: Int, mapId: Int): Option[ShuffleLocation] = { + override def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int): + Option[ShuffleLocation] = { mapStatuses.get(shuffleId) match { - case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocation + case Some(shuffleStatus) => shuffleStatus(mapId).shuffleLocationForBlock(reduceId) case None => Option.empty } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index f259b5a44f74..ede997cbf64a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -19,9 +19,8 @@ package org.apache.spark.scheduler import java.io.{Externalizable, ObjectInput, ObjectOutput} -import scala.collection.mutable - import org.roaringbitmap.RoaringBitmap +import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.config @@ -36,7 +35,7 @@ private[spark] sealed trait MapStatus { /** Location where this task was run. */ def location: BlockManagerId - def shuffleLocation: Option[ShuffleLocation] + def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] /** * Estimated size for the reduce block, in bytes. @@ -59,19 +58,20 @@ private[spark] object MapStatus { .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], - shuffleLocation: Option[ShuffleLocation]): MapStatus = { + shuffleLocations: Array[ShuffleLocation]): MapStatus = { + assert(uncompressedSizes.length == shuffleLocations.length) if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, shuffleLocation) + HighlyCompressedMapStatus(loc, uncompressedSizes, shuffleLocations) } else { - new CompressedMapStatus(loc, uncompressedSizes, shuffleLocation) + new CompressedMapStatus(loc, uncompressedSizes, shuffleLocations) } } def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, Option.empty) + HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } else { - new CompressedMapStatus(loc, uncompressedSizes, Option.empty) + new CompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) } } @@ -115,24 +115,26 @@ private[spark] object MapStatus { private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, private[this] var compressedSizes: Array[Byte], - private[this] var shuffleLoc: Option[ShuffleLocation]) + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { // For deserialization only protected def this() = this(null, null.asInstanceOf[Array[Byte]], null) def this(loc: BlockManagerId, uncompressedSizes: Array[Long], - shuffleLoc: Option[ShuffleLocation]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLoc) - } - - def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) { - this(loc, uncompressedSizes.map(MapStatus.compressSize), Option.empty) + shuffleLocations: Array[ShuffleLocation]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), shuffleLocations) } override def location: BlockManagerId = loc - override def shuffleLocation: Option[ShuffleLocation] = shuffleLoc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } override def getSizeForBlock(reduceId: Int): Long = { MapStatus.decompressSize(compressedSizes(reduceId)) @@ -142,12 +144,7 @@ private[spark] class CompressedMapStatus( loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) - if (shuffleLoc.isDefined) { - out.writeBoolean(true) - out.writeObject(shuffleLocation.get) - } else { - out.writeBoolean(false) - } + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -155,12 +152,7 @@ private[spark] class CompressedMapStatus( val len = in.readInt() compressedSizes = new Array[Byte](len) in.readFully(compressedSizes) - val shuffleLocationExists = in.readBoolean() - if (shuffleLocationExists) { - shuffleLoc = Option.apply(in.readObject().asInstanceOf[ShuffleLocation]) - } else { - shuffleLoc = Option.empty - } + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } @@ -181,7 +173,7 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], - private[this] var shuffleLoc: Option[ShuffleLocation]) + private[this] var shuffleLocations: Array[ShuffleLocation]) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -192,7 +184,13 @@ private[spark] class HighlyCompressedMapStatus private ( override def location: BlockManagerId = loc - override def shuffleLocation: Option[ShuffleLocation] = shuffleLoc + override def shuffleLocationForBlock(reduceId: Int): Option[ShuffleLocation] = { + if (shuffleLocations.apply(reduceId) == null) { + Option.empty + } else { + Option.apply(shuffleLocations.apply(reduceId)) + } + } override def getSizeForBlock(reduceId: Int): Long = { assert(hugeBlockSizes != null) @@ -215,12 +213,7 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeInt(kv._1) out.writeByte(kv._2) } - if (shuffleLoc.isDefined) { - out.writeBoolean(true) - out.writeObject(shuffleLoc.get) - } else { - out.writeBoolean(false) - } + out.writeObject(shuffleLocations) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -236,18 +229,17 @@ private[spark] class HighlyCompressedMapStatus private ( hugeBlockSizesImpl(block) = size } hugeBlockSizes = hugeBlockSizesImpl - val shuffleLocationExists = in.readBoolean() - if (shuffleLocationExists) { - shuffleLoc = Option.apply(in.readObject().asInstanceOf[ShuffleLocation]) - } else { - shuffleLoc = Option.empty - } + shuffleLocations = in.readObject().asInstanceOf[Array[ShuffleLocation]] } } private[spark] object HighlyCompressedMapStatus { + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = { + apply(loc, uncompressedSizes, Array.empty[ShuffleLocation]) + } + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], - shuffleLocation: Option[ShuffleLocation]): HighlyCompressedMapStatus = { + shuffleLocation: Array[ShuffleLocation]): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) when using the average block size. var i = 0 diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 33736d56706d..6842a58c3820 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -49,8 +49,6 @@ private[spark] class SortShuffleWriter[K, V, C]( private val writeMetrics = context.taskMetrics().shuffleWriteMetrics - private val shuffleLocation = Option.empty - /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { sorter = if (dep.mapSideCombine) { @@ -72,13 +70,14 @@ private[spark] class SortShuffleWriter[K, V, C]( val tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) + // TODO: fix this, return committed partition val partitionLengths = pluggableWriteSupport.map { writeSupport => sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) if (pluggableWriteSupport.isEmpty) { shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) } - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, shuffleLocation) + mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala index b03276b2ce16..e69a8fc43a17 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala @@ -55,11 +55,11 @@ private[spark] class ShufflePartitionObjectWriter( require(objectOutputStream != null, "Cannot commit a partition that has not been started.") require(currentWriter != null, "Cannot commit a partition that has not been started.") objectOutputStream.close() - val length = currentWriter.commitAndGetTotalLength() + val length = currentWriter.commitPartition() buffer.reset() currentWriter = null objectOutputStream = null - length + length.length() // TODO: update this } def abortCurrentPartition(throwable: Exception): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 569c8bd092f3..e50a0cbdd909 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -22,15 +22,14 @@ import java.util.Comparator import scala.collection.mutable import scala.collection.mutable.ArrayBuffer - import com.google.common.io.ByteStreams import org.apache.spark.{util, _} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ -import org.apache.spark.shuffle.api.ShuffleWriteSupport -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShufflePartitionObjectWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport} +import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index ce3c1aea895c..164903377378 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -23,6 +23,7 @@ import java.nio.file.StandardOpenOption; import java.util.*; +import org.apache.spark.shuffle.api.CommittedPartition; import scala.Option; import scala.Product2; import scala.Tuple2; @@ -675,7 +676,7 @@ public OutputStream openPartitionStream() { } @Override - public long commitAndGetTotalLength() { + public CommittedPartition commitPartition() { byte[] partitionBytes = byteBuffer.toByteArray(); try { Files.write(mergedOutputFile.toPath(), partitionBytes, StandardOpenOption.APPEND); @@ -684,7 +685,17 @@ public long commitAndGetTotalLength() { } int length = partitionBytes.length; partitionSizesInMergedFile[partitionId] = length; - return length; + return new CommittedPartition() { + @Override + public long length() { + return length; + } + + @Override + public Optional shuffleLocation() { + return Optional.empty(); + } + }; } @Override @@ -700,11 +711,6 @@ public void commitAllPartitions() { } - @Override - public Optional getShuffleLocation() { - return Optional.empty(); - } - @Override public void abort(Exception failureReason) { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 9506c86cdd5d..f3512b064c23 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer - import org.mockito.Matchers.any import org.mockito.Mockito._ @@ -27,7 +26,7 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleLocation} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -62,9 +61,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), Option.empty)) + Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L), Option.empty)) + Array(10000L, 1000L))) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), @@ -84,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000), Option.empty)) + Array(compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000), Option.empty)) + Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -107,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000), Option.empty)) + Array(compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000), Option.empty)) + Array(compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -145,7 +144,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), Option.empty)) + BlockManagerId("a", "hostA", 1000), Array(1000L))) slaveTracker.updateEpoch(masterTracker.getEpoch) assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) @@ -182,7 +181,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // Message size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), Option.empty)) + BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) when(rpcCallContext.senderAddress).thenReturn(senderAddress) @@ -216,11 +215,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { // on hostB with output size 3 tracker.registerShuffle(10, 3) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), Option.empty)) + Array(2L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(2L), Option.empty)) + Array(2L))) tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(3L), Option.empty)) + Array(3L))) // When the threshold is 50%, only host A should be returned as a preferred location // as it has 4 out of 7 bytes of output. @@ -260,7 +259,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), Array.empty[ShuffleLocation])) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) @@ -309,9 +308,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000), Option.empty)) + Array(size0, size1000, size0, size10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0), Option.empty)) + Array(size10000, size0, size1000, size0))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 4).toSeq === Seq( diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index 68a463bc7cef..097d1e406dc0 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.io._ import java.nio.file.Paths import java.util.Optional - import javax.ws.rs.core.UriBuilder import org.apache.spark.shuffle.api._ @@ -52,8 +51,14 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { new FileOutputStream(shuffleFile) } - override def commitAndGetTotalLength(): Long = - resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + override def commitPartition(): CommittedPartition = { + new CommittedPartition { + override def length(): Long = + resolvePartitionFile(appId, shuffleId, mapId, partitionId).length + + override def shuffleLocation(): Optional[ShuffleLocation] = Optional.empty() + } + } override def abort(failureReason: Exception): Unit = {} } @@ -62,14 +67,11 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { override def commitAllPartitions(): Unit = {} override def abort(exception: Exception): Unit = {} - - override def getShuffleLocation: Optional[ShuffleLocation] = Optional.empty() } } private def resolvePartitionFile( appId: String, shuffleId: Int, mapId: Int, reduceId: Int): File = { - import java.io.OutputStream Paths.get(UriBuilder.fromUri(shuffleDir.toURI) .path(appId) .path(shuffleId.toString) diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 6e89ab710206..7040c632d3c5 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -26,7 +26,6 @@ import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import scala.reflect.ClassTag - import com.esotericsoftware.kryo.{Kryo, KryoException} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import org.roaringbitmap.RoaringBitmap @@ -34,7 +33,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -350,7 +349,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => ser.serialize(HighlyCompressedMapStatus( - BlockManagerId("exec-1", "host", 1234), blockSizes, Option.empty)) + BlockManagerId("exec-1", "host", 1234), blockSizes)) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala index 883ac10718df..41d8150acc19 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -48,6 +48,7 @@ object GroupByShuffleTest { .collect() println(wordCountsWithGroup2.mkString(",")) + Thread.sleep(100000) spark.stop() } From d7919f2d4a47116731b645da56ebbaaabf1e392b Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Thu, 17 Jan 2019 15:58:19 -0800 Subject: [PATCH 21/30] Bypass Merge sort works --- .../network/shuffle/protocol/RegisterShuffleIndex.java | 4 ++-- .../external/ExternalShuffleMapOutputWriter.java | 2 +- .../external/ExternalShufflePartitionWriter.java | 10 ++++------ .../shuffle/sort/BypassMergeSortShuffleWriter.java | 2 +- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java index 27f101171834..bc870e440274 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/RegisterShuffleIndex.java @@ -43,8 +43,8 @@ public RegisterShuffleIndex( @Override public boolean equals(Object other) { - if (other != null && other instanceof UploadShufflePartitionStream) { - UploadShufflePartitionStream o = (UploadShufflePartitionStream) other; + if (other != null && other instanceof RegisterShuffleIndex) { + RegisterShuffleIndex o = (RegisterShuffleIndex) other; return Objects.equal(appId, o.appId) && shuffleId == o.shuffleId && mapId == o.mapId; diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java index fd23772d89e7..8866d14feca5 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleMapOutputWriter.java @@ -80,8 +80,8 @@ public void commitAllPartitions() { logger.info("clientid: " + client.getClientId() + " " + client.isActive()); client.sendRpcSync(uploadShuffleIndex, 60000); } catch (Exception e) { - client.close(); logger.error("Encountered error while creating transport client", e); + client.close(); throw new RuntimeException(e); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index ef0c4842b17a..39595954dafb 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -22,7 +22,7 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); - private final TransportClientFactory clientFactory; + private final TransportClient client; private final String hostName; private final int port; private final String appId; @@ -40,8 +40,8 @@ public ExternalShufflePartitionWriter( String appId, int shuffleId, int mapId, - int partitionId) { - this.clientFactory = clientFactory; + int partitionId) throws IOException, InterruptedException { + this.client = clientFactory.createUnmanagedClient(hostName, port); this.hostName = hostName; this.port = port; this.appId = appId; @@ -66,14 +66,13 @@ public void onFailure(Throwable e) { logger.error("Encountered an error uploading partition", e); } }; - TransportClient client = null; try { byte[] buf = partitionBuffer.toByteArray(); int size = buf.length; ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId, size).toByteBuffer(); ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); - client = clientFactory.createUnmanagedClient(hostName, port); + client.setClientId(String.format("data-%s-%d-%d-%d", appId, shuffleId, mapId, partitionId)); logger.info("clientid: " + client.getClientId() + " " + client.isActive()); @@ -96,7 +95,6 @@ public void onFailure(Throwable e) { @Override public void abort(Exception failureReason) { - clientFactory.close(); try { this.partitionBuffer.close(); } catch(IOException e) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index d90a05ab9a20..4a643e04b66f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -121,6 +121,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleBlockResolver = shuffleBlockResolver; this.pluggableWriteSupport = pluggableWriteSupport; this.appId = conf.getAppId(); + this.shuffleLocations = new ShuffleLocation[numPartitions]; } @Override @@ -128,7 +129,6 @@ public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { partitionLengths = new long[numPartitions]; - shuffleLocations = new ShuffleLocation[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); return; From 0befe41116e334f436eb4202e1bb0b5bd4280680 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 18 Jan 2019 10:22:24 -0800 Subject: [PATCH 22/30] small refactors --- .../ExternalShufflePartitionWriter.java | 9 ++++--- .../shuffle/sort/LocalCommittedPartition.java | 25 +++++++++++++++++ .../apache/spark/scheduler/MapStatus.scala | 15 +++++++++++ .../shuffle/sort/SortShuffleWriter.scala | 9 ++++--- .../ShufflePartitionObjectWriter.scala | 8 +++--- .../util/collection/ExternalSorter.scala | 27 ++++++++++--------- .../apache/spark/MapOutputTrackerSuite.scala | 11 ++++---- 7 files changed, 75 insertions(+), 29 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 39595954dafb..89bfe4407e5a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -22,7 +22,7 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private static final Logger logger = LoggerFactory.getLogger(ExternalShufflePartitionWriter.class); - private final TransportClient client; + private final TransportClientFactory clientFactory; private final String hostName; private final int port; private final String appId; @@ -40,8 +40,8 @@ public ExternalShufflePartitionWriter( String appId, int shuffleId, int mapId, - int partitionId) throws IOException, InterruptedException { - this.client = clientFactory.createUnmanagedClient(hostName, port); + int partitionId) { + this.clientFactory = clientFactory; this.hostName = hostName; this.port = port; this.appId = appId; @@ -66,13 +66,14 @@ public void onFailure(Throwable e) { logger.error("Encountered an error uploading partition", e); } }; + TransportClient client = null; try { byte[] buf = partitionBuffer.toByteArray(); int size = buf.length; ByteBuffer streamHeader = new UploadShufflePartitionStream(appId, shuffleId, mapId, partitionId, size).toByteBuffer(); ManagedBuffer managedBuffer = new NioManagedBuffer(ByteBuffer.wrap(buf)); - + client = clientFactory.createUnmanagedClient(hostName, port); client.setClientId(String.format("data-%s-%d-%d-%d", appId, shuffleId, mapId, partitionId)); logger.info("clientid: " + client.getClientId() + " " + client.isActive()); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java new file mode 100644 index 000000000000..817855d95796 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/LocalCommittedPartition.java @@ -0,0 +1,25 @@ +package org.apache.spark.shuffle.sort; + +import org.apache.spark.shuffle.api.CommittedPartition; +import org.apache.spark.storage.ShuffleLocation; + +import java.util.Optional; + +public class LocalCommittedPartition implements CommittedPartition { + + private final long length; + + public LocalCommittedPartition(long length) { + this.length = length; + } + + @Override + public long length() { + return length; + } + + @Override + public Optional shuffleLocation() { + return Optional.empty(); + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index ede997cbf64a..ac6cfc5e870f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -24,6 +24,7 @@ import scala.collection.mutable import org.apache.spark.SparkEnv import org.apache.spark.internal.config +import org.apache.spark.shuffle.api.CommittedPartition import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils @@ -57,6 +58,20 @@ private[spark] object MapStatus { .map(_.conf.get(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS)) .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) + def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = { + val shuffleLocationsArray = committedPartitions.map(a => { + a.shuffleLocation() match { + case empty if empty.isPresent => empty.get() + case _ => null + } + }) + if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) { + HighlyCompressedMapStatus(loc, committedPartitions.map(_.length()), shuffleLocationsArray) + } else { + new CompressedMapStatus(loc, committedPartitions.map(_.length()), shuffleLocationsArray) + } + } + def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], shuffleLocations: Array[ShuffleLocation]): MapStatus = { assert(uncompressedSizes.length == shuffleLocations.length) diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 6842a58c3820..d3afe31eae7c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -71,13 +71,16 @@ private[spark] class SortShuffleWriter[K, V, C]( try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) // TODO: fix this, return committed partition - val partitionLengths = pluggableWriteSupport.map { writeSupport => + val committedPartitions = pluggableWriteSupport.map { writeSupport => sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) if (pluggableWriteSupport.isEmpty) { - shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp) + shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, + mapId, + committedPartitions.map(_.length()), + tmp) } - mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths) + mapStatus = MapStatus(blockManager.shuffleServerId, committedPartitions) } finally { if (tmp.exists() && !tmp.delete()) { logError(s"Error while deleting temp file ${tmp.getAbsolutePath}") diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala index e69a8fc43a17..baaee46f8123 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.serializer.{SerializationStream, SerializerInstance} import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream -import org.apache.spark.shuffle.api.{ShuffleMapOutputWriter, ShufflePartitionWriter} +import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter} /** * Replicates the concept of {@link DiskBlockObjectWriter}, but with some key differences: @@ -51,15 +51,15 @@ private[spark] class ShufflePartitionObjectWriter( objectOutputStream = serializerInstance.serializeStream(currentWriterStream) } - def commitCurrentPartition(): Long = { + def commitCurrentPartition(): CommittedPartition = { require(objectOutputStream != null, "Cannot commit a partition that has not been started.") require(currentWriter != null, "Cannot commit a partition that has not been started.") objectOutputStream.close() - val length = currentWriter.commitPartition() + val committedPartition = currentWriter.commitPartition() buffer.reset() currentWriter = null objectOutputStream = null - length.length() // TODO: update this + committedPartition } def abortCurrentPartition(throwable: Exception): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index e50a0cbdd909..69077c644dc7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -18,18 +18,19 @@ package org.apache.spark.util.collection import java.io._ -import java.util.Comparator +import java.util.{Comparator, Optional} +import com.google.common.io.ByteStreams import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.google.common.io.ByteStreams -import org.apache.spark.{util, _} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport} +import org.apache.spark.shuffle.sort.LocalCommittedPartition import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter} +import org.apache.spark.{util, _} /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -682,10 +683,10 @@ private[spark] class ExternalSorter[K, V, C]( */ def writePartitionedFile( blockId: BlockId, - outputFile: File): Array[Long] = { + outputFile: File): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) + val committedPartitions = new Array[CommittedPartition](numPartitions) val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, context.taskMetrics().shuffleWriteMetrics) @@ -699,7 +700,7 @@ private[spark] class ExternalSorter[K, V, C]( it.writeNext(writer) } val segment = writer.commitAndGet() - lengths(partitionId) = segment.length + committedPartitions(partitionId) = new LocalCommittedPartition(segment.length) } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. @@ -709,7 +710,7 @@ private[spark] class ExternalSorter[K, V, C]( writer.write(elem._1, elem._2) } val segment = writer.commitAndGet() - lengths(id) = segment.length + committedPartitions(id) = new LocalCommittedPartition(segment.length) } } } @@ -719,17 +720,17 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } /** * Write all partitions to some backend that is pluggable. */ def writePartitionedToExternalShuffleWriteSupport( - mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[Long] = { + mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = { // Track location of each range in the output file - val lengths = new Array[Long](numPartitions) + val committedPartitions = new Array[CommittedPartition](numPartitions) val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId) val writer = new ShufflePartitionObjectWriter( Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt, @@ -748,7 +749,7 @@ private[spark] class ExternalSorter[K, V, C]( while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) } - lengths(partitionId) = writer.commitCurrentPartition() + committedPartitions(partitionId) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -766,7 +767,7 @@ private[spark] class ExternalSorter[K, V, C]( for (elem <- elements) { writer.write(elem._1, elem._2) } - lengths(id) = writer.commitCurrentPartition() + committedPartitions(id) = writer.commitCurrentPartition() } catch { case e: Exception => util.Utils.tryLogNonFatalError { @@ -790,7 +791,7 @@ private[spark] class ExternalSorter[K, V, C]( context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) - lengths + committedPartitions } def stop(): Unit = { diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index f3512b064c23..90f6c3523ece 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -83,9 +83,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize10000))) + Array[Long](compressedSize1000, compressedSize10000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) @@ -106,9 +106,9 @@ class MapOutputTrackerSuite extends SparkFunSuite { val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(compressedSize1000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize1000, compressedSize1000, compressedSize1000))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(compressedSize10000, compressedSize1000, compressedSize1000))) + Array[Long](compressedSize10000, compressedSize1000, compressedSize1000))) assert(0 == tracker.getNumCachedSerializedBroadcast) // As if we had two simultaneous fetch failures @@ -259,7 +259,8 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(20, 100) (0 until 100).foreach { i => masterTracker.registerMapOutput(20, i, new CompressedMapStatus( - BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), Array.empty[ShuffleLocation])) + BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), + Array.empty[ShuffleLocation])) } val senderAddress = RpcAddress("localhost", 12345) val rpcCallContext = mock(classOf[RpcCallContext]) From 4fba8d2e526cad7c5c4b25bb15585cc6dc94d424 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 18 Jan 2019 13:44:07 -0800 Subject: [PATCH 23/30] done refactoring --- .../external/ExternalShuffleLocation.java | 25 +------ .../sort/BypassMergeSortShuffleWriter.java | 54 ++++++++------- .../shuffle/sort/UnsafeShuffleWriter.java | 68 ++++++++++--------- .../apache/spark/scheduler/MapStatus.scala | 21 ++++-- 4 files changed, 81 insertions(+), 87 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java index 3df7aded2fc4..20ae8d376050 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -20,18 +20,12 @@ public ExternalShuffleLocation(String shuffleHostname, int shufflePort) { @Override public void writeExternal(ObjectOutput out) throws IOException { -// out.writeInt(shuffleHostname.length()); -// out.writeChars(shuffleHostname); out.writeUTF(shuffleHostname); out.writeInt(shufflePort); } @Override - public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException { -// int size = in.readInt(); -// byte[] buf = new byte[size]; -// in.read(buf, 0, size); -// this.shuffleHostname = new String(buf); + public void readExternal(ObjectInput in) throws IOException { this.shuffleHostname = in.readUTF(); this.shufflePort = in.readInt(); } @@ -43,21 +37,4 @@ public String getShuffleHostname() { public int getShufflePort() { return this.shufflePort; } - - - public static void main(String[] args) throws IOException, ClassNotFoundException { - ExternalShuffleLocation externalShuffleLocation = new ExternalShuffleLocation("hostname", 1234); - ShuffleLocation shuffleLocation = (ShuffleLocation) externalShuffleLocation; - - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(shuffleLocation); - oos.flush(); - - - ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); - ObjectInputStream ois = new ObjectInputStream(bais); - ShuffleLocation newShuffLocation = (ShuffleLocation) ois.readObject(); - System.out.println(newShuffLocation); - } } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 4a643e04b66f..e33274a5f31d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -45,6 +45,8 @@ import javax.annotation.Nullable; import java.io.*; +import java.util.Arrays; +import java.util.stream.Collectors; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -89,8 +91,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { private DiskBlockObjectWriter[] partitionWriters; private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; - private long[] partitionLengths; - private ShuffleLocation[] shuffleLocations; + private CommittedPartition[] committedPartitions; /** * Are we in the process of stopping? Because map tasks can call stop() with success = true @@ -121,16 +122,16 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleBlockResolver = shuffleBlockResolver; this.pluggableWriteSupport = pluggableWriteSupport; this.appId = conf.getAppId(); - this.shuffleLocations = new ShuffleLocation[numPartitions]; +// this.committedPartitions = new CommittedPartition[numPartitions]; } @Override public void write(Iterator> records) throws IOException { assert (partitionWriters == null); if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; + long[] partitionLengths = new long[numPartitions]; shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); return; } final SerializerInstance serInstance = serializer.newInstance(); @@ -163,25 +164,33 @@ public void write(Iterator> records) throws IOException { } if (pluggableWriteSupport != null) { - partitionLengths = combineAndWritePartitionsUsingPluggableWriter(); + committedPartitions = combineAndWritePartitionsUsingPluggableWriter(); + logger.info("Successfully wrote partitions with pluggable writer"); } else { File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); File tmp = Utils.tempFileWith(output); try { - partitionLengths = combineAndWritePartitions(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + committedPartitions = combineAndWritePartitions(tmp); + logger.info("Successfully wrote partitions without shuffle"); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(), + tmp); } finally { if (tmp != null && tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); + logger.info("value of committedPartitions: " + committedPartitions); + logger.info("length of committed partitions:" + committedPartitions.length); + logger.info("length of committed partition value: " + committedPartitions[0].length()); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting long[] getPartitionLengths() { - return partitionLengths; + return Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(); } /** @@ -189,12 +198,12 @@ long[] getPartitionLengths() { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] combineAndWritePartitions(File outputFile) throws IOException { + private CommittedPartition[] combineAndWritePartitions(File outputFile) throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(outputFile != null); final FileOutputStream out = new FileOutputStream(outputFile, true); @@ -207,7 +216,8 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { final FileInputStream in = new FileInputStream(file); boolean copyThrewException = true; try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); + partitions[i] = + new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled)); copyThrewException = false; } finally { Closeables.close(in, copyThrewException); @@ -222,15 +232,15 @@ private long[] combineAndWritePartitions(File outputFile) throws IOException { writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } - private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { + private CommittedPartition[] combineAndWritePartitionsUsingPluggableWriter() throws IOException { // Track location of the partition starts in the output file - final long[] lengths = new long[numPartitions]; + final CommittedPartition[] partitions = new CommittedPartition[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator - return lengths; + return partitions; } assert(pluggableWriteSupport != null); @@ -248,11 +258,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio try (OutputStream out = writer.openPartitionStream()) { Utils.copyStream(in, out, false, false); } - CommittedPartition committedPartition = writer.commitPartition(); - lengths[i] = committedPartition.length(); - if (committedPartition.shuffleLocation().isPresent()) { - shuffleLocations[i] = committedPartition.shuffleLocation().get(); - } + partitions[i] = writer.commitPartition(); copyThrewException = false; } catch (Exception e) { try { @@ -280,7 +286,7 @@ private long[] combineAndWritePartitionsUsingPluggableWriter() throws IOExceptio writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; - return lengths; + return partitions; } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 5cb9d391f060..ef086e21b04d 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -20,7 +20,10 @@ import javax.annotation.Nullable; import java.io.*; import java.nio.channels.FileChannel; +import java.util.Arrays; import java.util.Iterator; +import java.util.stream.Collectors; +import java.util.stream.Stream; import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.storage.ShuffleLocation; @@ -91,7 +94,6 @@ public class UnsafeShuffleWriter extends ShuffleWriter { @Nullable private MapStatus mapStatus; @Nullable private ShuffleExternalSorter sorter; - private ShuffleLocation[] shuffleLocations; private long peakMemoryUsedBytes = 0; /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */ @@ -158,7 +160,6 @@ public UnsafeShuffleWriter( (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.outputBufferSizeInBytes = (int) (long) sparkConf.get(package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; - this.shuffleLocations = new ShuffleLocation[numPartitions]; open(); } @@ -240,12 +241,12 @@ void closeAndWriteOutput() throws IOException { serOutputStream = null; final SpillInfo[] spills = sorter.closeAndGetSpills(); sorter = null; - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); final File tmp = Utils.tempFileWith(output); try { try { - partitionLengths = mergeSpills(spills, tmp); + committedPartitions = mergeSpills(spills, tmp); } finally { for (SpillInfo spill : spills) { if (spill.file.exists() && ! spill.file.delete()) { @@ -254,14 +255,17 @@ void closeAndWriteOutput() throws IOException { } } if (pluggableWriteSupport == null) { - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); + shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, + mapId, + Arrays.stream(committedPartitions).mapToLong(CommittedPartition::length).toArray(), + tmp); } } finally { if (tmp.exists() && !tmp.delete()) { logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); } } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths, shuffleLocations); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } @VisibleForTesting @@ -293,7 +297,7 @@ void forceSorterToSpill() throws IOException { * * @return the partition lengths in the merged file. */ - private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { + private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException { final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true); final CompressionCodec compressionCodec = compressionEnabled ? CompressionCodec$.MODULE$.createCodec(sparkConf) : null; @@ -305,19 +309,18 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file - return new long[partitioner.numPartitions()]; + return new CommittedPartition[partitioner.numPartitions()]; } else if (spills.length == 1) { if (pluggableWriteSupport != null) { - // TODO: should this be returning a partition length? - writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); + return writeSingleSpillFileUsingPluggableWriter(spills[0], compressionCodec); } else { // Here, we don't need to perform any metrics updates because the bytes written to this // output file would have already been counted as shuffle bytes written. Files.move(spills[0].file, outputFile); } - return spills[0].partitionLengths; + return toLocalCommittedPartition(spills[0].partitionLengths); } else { - final long[] partitionLengths; + final CommittedPartition[] committedPartitions; // There are multiple spills to merge, so none of these spill files' lengths were counted // towards our shuffle write count or shuffle write time. If we use the slow merge path, // then the final output file's size won't necessarily be equal to the sum of the spill @@ -329,21 +332,21 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti // shuffle write time, which appears to be consistent with the "not bypassing merge-sort" // branch in ExternalSorter. if (pluggableWriteSupport != null) { - partitionLengths = mergeSpillsWithPluggableWriter(spills, compressionCodec); + committedPartitions = mergeSpillsWithPluggableWriter(spills, compressionCodec); } else if (fastMergeEnabled && fastMergeIsSupported) { // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile)); } else { logger.debug("Using fileStream-based fast merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, null)); } } else { logger.debug("Using slow merge"); - partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec); + committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, compressionCodec)); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that @@ -354,7 +357,7 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti if (pluggableWriteSupport == null) { writeMetrics.incBytesWritten(outputFile.length()); } - return partitionLengths; + return committedPartitions; } } catch (IOException e) { if (outputFile.exists() && !outputFile.delete()) { @@ -364,6 +367,12 @@ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOExcepti } } + private static CommittedPartition[] toLocalCommittedPartition(long[] partitionLengths) { + return Arrays.stream(partitionLengths) + .mapToObj(length -> new LocalCommittedPartition(length)) + .collect(Collectors.toList()).toArray(new CommittedPartition[partitionLengths.length]); + } + /** * Merges spill files using Java FileStreams. This code path is typically slower than * the NIO-based merge, {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], @@ -517,13 +526,13 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th /** * Merges spill files using the ShufflePartitionWriter API. */ - private long[] mergeSpillsWithPluggableWriter( + private CommittedPartition[] mergeSpillsWithPluggableWriter( SpillInfo[] spills, @Nullable CompressionCodec compressionCodec) throws IOException { assert (spills.length >= 2); assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); - final long[] partitionLengths = new long[numPartitions]; + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; final InputStream[] spillInputStreams = new InputStream[spills.length]; boolean threwException = true; @@ -557,12 +566,8 @@ private long[] mergeSpillsWithPluggableWriter( } } } - CommittedPartition committedPartition = writer.commitPartition(); - if (committedPartition.shuffleLocation().isPresent()) { - shuffleLocations[partition] = committedPartition.shuffleLocation().get(); - } - partitionLengths[partition] = committedPartition.length(); - writeMetrics.incBytesWritten(partitionLengths[partition]); + committedPartitions[partition] = writer.commitPartition(); + writeMetrics.incBytesWritten(committedPartitions[partition].length()); } catch (Exception e) { try { writer.abort(e); @@ -588,14 +593,15 @@ private long[] mergeSpillsWithPluggableWriter( Closeables.close(stream, threwException); } } - return partitionLengths; + return committedPartitions; } - private void writeSingleSpillFileUsingPluggableWriter( + private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter( SpillInfo spillInfo, @Nullable CompressionCodec compressionCodec) throws IOException { assert(pluggableWriteSupport != null); final int numPartitions = partitioner.numPartitions(); + final CommittedPartition[] committedPartitions = new CommittedPartition[numPartitions]; boolean threwException = true; InputStream spillInputStream = new NioBufferedFileInputStream( spillInfo.file, @@ -626,11 +632,8 @@ private void writeSingleSpillFileUsingPluggableWriter( } finally { partitionInputStream.close(); } - CommittedPartition committedPartition = writer.commitPartition(); - if (committedPartition.shuffleLocation().isPresent()) { - shuffleLocations[partition] = committedPartition.shuffleLocation().get(); - } - writeMetrics.incBytesWritten(committedPartition.length()); + committedPartitions[partition] = writer.commitPartition(); + writeMetrics.incBytesWritten(committedPartitions[partition].length()); } threwException = false; } catch (Exception e) { @@ -644,6 +647,7 @@ private void writeSingleSpillFileUsingPluggableWriter( Closeables.close(spillInputStream, threwException); } writeMetrics.decBytesWritten(spillInfo.file.length()); + return committedPartitions; } @Override diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index ac6cfc5e870f..758ceba71398 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -23,7 +23,7 @@ import org.roaringbitmap.RoaringBitmap import scala.collection.mutable import org.apache.spark.SparkEnv -import org.apache.spark.internal.config +import org.apache.spark.internal.{Logging, config} import org.apache.spark.shuffle.api.CommittedPartition import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils @@ -48,7 +48,7 @@ private[spark] sealed trait MapStatus { } -private[spark] object MapStatus { +private[spark] object MapStatus extends Logging { /** * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test @@ -59,16 +59,23 @@ private[spark] object MapStatus { .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = { - val shuffleLocationsArray = committedPartitions.map(a => { - a.shuffleLocation() match { - case empty if empty.isPresent => empty.get() + val shuffleLocationsArray = committedPartitions.map(partition => { + partition match { + case partition if partition != null && partition.shuffleLocation().isPresent + => partition.shuffleLocation().get() case _ => null } }) + val lengthsArray = committedPartitions.map(partition => { + partition match { + case partition if partition != null => partition.length() + case _ => 0 + } + }) if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, committedPartitions.map(_.length()), shuffleLocationsArray) + HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) } else { - new CompressedMapStatus(loc, committedPartitions.map(_.length()), shuffleLocationsArray) + new CompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) } } From a5ee74660e78a280a0a402cde29ecd25d8bac235 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 18 Jan 2019 13:56:49 -0800 Subject: [PATCH 24/30] more cleanup --- .../scala/org/apache/spark/scheduler/MapStatus.scala | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 758ceba71398..65eada767992 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -79,16 +79,6 @@ private[spark] object MapStatus extends Logging { } } - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long], - shuffleLocations: Array[ShuffleLocation]): MapStatus = { - assert(uncompressedSizes.length == shuffleLocations.length) - if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { - HighlyCompressedMapStatus(loc, uncompressedSizes, shuffleLocations) - } else { - new CompressedMapStatus(loc, uncompressedSizes, shuffleLocations) - } - } - def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = { if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { HighlyCompressedMapStatus(loc, uncompressedSizes, Array.empty[ShuffleLocation]) From 6e86ac05304f0eb98e59f0cb565e960ba9adeae3 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 18 Jan 2019 14:17:39 -0800 Subject: [PATCH 25/30] more housekeeping --- .../shuffle/sort/BypassMergeSortShuffleWriter.java | 4 ---- .../spark/shuffle/sort/SortShuffleWriter.scala | 1 - .../spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 12 +----------- .../spark/serializer/KryoSerializerSuite.scala | 5 ++--- .../apache/spark/examples/GroupByShuffleTest.scala | 1 - 5 files changed, 3 insertions(+), 20 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index e33274a5f31d..26b55aa70387 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -122,7 +122,6 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleBlockResolver = shuffleBlockResolver; this.pluggableWriteSupport = pluggableWriteSupport; this.appId = conf.getAppId(); -// this.committedPartitions = new CommittedPartition[numPartitions]; } @Override @@ -182,9 +181,6 @@ public void write(Iterator> records) throws IOException { } } } - logger.info("value of committedPartitions: " + committedPartitions); - logger.info("length of committed partitions:" + committedPartitions.length); - logger.info("length of committed partition value: " + committedPartitions[0].length()); mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), committedPartitions); } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index d3afe31eae7c..98388d80cbe5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -70,7 +70,6 @@ private[spark] class SortShuffleWriter[K, V, C]( val tmp = Utils.tempFileWith(output) try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) - // TODO: fix this, return committed partition val committedPartitions = pluggableWriteSupport.map { writeSupport => sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 164903377378..93ab301a4cb9 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -685,17 +685,7 @@ public CommittedPartition commitPartition() { } int length = partitionBytes.length; partitionSizesInMergedFile[partitionId] = length; - return new CommittedPartition() { - @Override - public long length() { - return length; - } - - @Override - public Optional shuffleLocation() { - return Optional.empty(); - } - }; + return new LocalCommittedPartition(length); } @Override diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 7040c632d3c5..c765be7300b4 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -33,7 +33,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} +import org.apache.spark.storage.{BlockManagerId} import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { @@ -348,8 +348,7 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { val denseBlockSizes = new Array[Long](5000) val sparseBlockSizes = Array[Long](0L, 1L, 0L, 2L) Seq(denseBlockSizes, sparseBlockSizes).foreach { blockSizes => - ser.serialize(HighlyCompressedMapStatus( - BlockManagerId("exec-1", "host", 1234), blockSizes)) + ser.serialize(HighlyCompressedMapStatus(BlockManagerId("exec-1", "host", 1234), blockSizes)) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala index 41d8150acc19..883ac10718df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByShuffleTest.scala @@ -48,7 +48,6 @@ object GroupByShuffleTest { .collect() println(wordCountsWithGroup2.mkString(",")) - Thread.sleep(100000) spark.stop() } From 4a12c93b1540bf2017cd5a2d0b1ad712dbb3e635 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Fri, 18 Jan 2019 14:23:59 -0800 Subject: [PATCH 26/30] sweep sweep --- .../org/apache/spark/scheduler/MapStatus.scala | 17 +++++++---------- .../spark/serializer/KryoSerializerSuite.scala | 2 +- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index 65eada767992..21613a5946f6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -23,7 +23,7 @@ import org.roaringbitmap.RoaringBitmap import scala.collection.mutable import org.apache.spark.SparkEnv -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.config import org.apache.spark.shuffle.api.CommittedPartition import org.apache.spark.storage.{BlockManagerId, ShuffleLocation} import org.apache.spark.util.Utils @@ -48,7 +48,7 @@ private[spark] sealed trait MapStatus { } -private[spark] object MapStatus extends Logging { +private[spark] object MapStatus { /** * Min partition number to use [[HighlyCompressedMapStatus]]. A bit ugly here because in test @@ -59,19 +59,16 @@ private[spark] object MapStatus extends Logging { .getOrElse(config.SHUFFLE_MIN_NUM_PARTS_TO_HIGHLY_COMPRESS.defaultValue.get) def apply(loc: BlockManagerId, committedPartitions: Array[CommittedPartition]): MapStatus = { - val shuffleLocationsArray = committedPartitions.map(partition => { - partition match { + val shuffleLocationsArray = committedPartitions.collect { case partition if partition != null && partition.shuffleLocation().isPresent => partition.shuffleLocation().get() case _ => null - } - }) - val lengthsArray = committedPartitions.map(partition => { - partition match { + } + val lengthsArray = committedPartitions.collect { case partition if partition != null => partition.length() case _ => 0 - } - }) + + } if (committedPartitions.length > minPartitionsToUseHighlyCompressMapStatus) { HighlyCompressedMapStatus(loc, lengthsArray, shuffleLocationsArray) } else { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index c765be7300b4..2e1950bbc7f5 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -33,7 +33,7 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.storage.{BlockManagerId} +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{ThreadUtils, Utils} class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { From 75ecb6652b01b87972f64a32c15d7a3b9293fabb Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Tue, 22 Jan 2019 10:49:33 -0800 Subject: [PATCH 27/30] Update ShuffleLocation to be part of the read API too --- .../shuffle/api/ShufflePartitionReader.java | 5 ++++- .../external/ExternalShuffleDataIO.java | 5 +---- .../ExternalShufflePartitionReader.java | 21 ++++++++++++------- .../external/ExternalShuffleReadSupport.java | 13 +----------- .../shuffle/BlockStoreShuffleReader.scala | 10 ++++++--- .../apache/spark/SplitFilesShuffleIO.scala | 2 +- 6 files changed, 27 insertions(+), 29 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java index 59eae0a78220..817d213cd8cc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java @@ -17,9 +17,12 @@ package org.apache.spark.shuffle.api; +import org.apache.spark.storage.ShuffleLocation; + import java.io.InputStream; +import java.util.Optional; public interface ShufflePartitionReader { - InputStream fetchPartition(int reduceId); + InputStream fetchPartition(int reduceId, Optional shuffleLocation); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index ac20d13de6f2..2a0a39e4b82e 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -21,7 +21,6 @@ public class ExternalShuffleDataIO implements ShuffleDataIO { private static SecurityManager securityManager; private static String hostname; private static int port; - private static MapOutputTracker mapOutputTracker; public ExternalShuffleDataIO( SparkConf sparkConf) { @@ -37,15 +36,13 @@ public void initialize() { securityManager = env.securityManager(); hostname = blockManager.getRandomShuffleHost(); port = blockManager.getRandomShufflePort(); - mapOutputTracker = env.mapOutputTracker(); // TODO: Register Driver and Executor } @Override public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( - conf, context, securityManager.isAuthenticationEnabled(), - securityManager, mapOutputTracker); + conf, context, securityManager.isAuthenticationEnabled(), securityManager); } @Override diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index 8aefac239e97..f02783263925 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -4,14 +4,17 @@ import org.apache.spark.network.client.TransportClientFactory; import org.apache.spark.network.shuffle.protocol.OpenShufflePartition; import org.apache.spark.shuffle.api.ShufflePartitionReader; +import org.apache.spark.storage.ShuffleLocation; import org.apache.spark.util.ByteBufferInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import scala.compat.java8.OptionConverters; import java.io.ByteArrayInputStream; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Arrays; +import java.util.Optional; public class ExternalShufflePartitionReader implements ShufflePartitionReader { @@ -19,34 +22,36 @@ public class ExternalShufflePartitionReader implements ShufflePartitionReader { LoggerFactory.getLogger(ExternalShufflePartitionReader.class); private final TransportClientFactory clientFactory; - private final String hostName; - private final int port; private final String appId; private final int shuffleId; private final int mapId; public ExternalShufflePartitionReader( TransportClientFactory clientFactory, - String hostName, - int port, String appId, int shuffleId, int mapId) { this.clientFactory = clientFactory; - this.hostName = hostName; - this.port = port; this.appId = appId; this.shuffleId = shuffleId; this.mapId = mapId; } @Override - public InputStream fetchPartition(int reduceId) { + public InputStream fetchPartition(int reduceId, Optional shuffleLocation) { + assert shuffleLocation.isPresent() && shuffleLocation.get() instanceof ExternalShuffleLocation; + ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) shuffleLocation.get(); + logger.info(String.format("Found external shuffle location on node: %s:%d", + externalShuffleLocation.getShuffleHostname(), + externalShuffleLocation.getShufflePort())); + String hostname = externalShuffleLocation.getShuffleHostname(); + int port = externalShuffleLocation.getShufflePort(); + OpenShufflePartition openMessage = new OpenShufflePartition(appId, shuffleId, mapId, reduceId); TransportClient client = null; try { - client = clientFactory.createUnmanagedClient(hostName, port); + client = clientFactory.createUnmanagedClient(hostname, port); String requestID = String.format( "read-%s-%d-%d-%d", appId, shuffleId, mapId, reduceId); client.setClientId(requestID); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index 9e7ff55f4774..a671b80904ed 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -26,19 +26,16 @@ public class ExternalShuffleReadSupport implements ShuffleReadSupport { private final TransportContext context; private final boolean authEnabled; private final SecretKeyHolder secretKeyHolder; - private final MapOutputTracker mapOutputTracker; public ExternalShuffleReadSupport( TransportConf conf, TransportContext context, boolean authEnabled, - SecretKeyHolder secretKeyHolder, - MapOutputTracker mapOutputTracker) { + SecretKeyHolder secretKeyHolder) { this.conf = conf; this.context = context; this.authEnabled = authEnabled; this.secretKeyHolder = secretKeyHolder; - this.mapOutputTracker = mapOutputTracker; } @Override @@ -48,17 +45,9 @@ public ShufflePartitionReader newPartitionReader(String appId, int shuffleId, in if (authEnabled) { bootstraps.add(new AuthClientBootstrap(conf, appId, secretKeyHolder)); } - Optional maybeShuffleLocation = OptionConverters.toJava(mapOutputTracker.getShuffleLocation(shuffleId, mapId, 0)); - assert maybeShuffleLocation.isPresent(); - ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) maybeShuffleLocation.get(); - logger.info(String.format("Found external shuffle location on node: %s:%d", - externalShuffleLocation.getShuffleHostname(), - externalShuffleLocation.getShufflePort())); TransportClientFactory clientFactory = context.createClientFactory(bootstraps); try { return new ExternalShufflePartitionReader(clientFactory, - externalShuffleLocation.getShuffleHostname(), - externalShuffleLocation.getShufflePort(), appId, shuffleId, mapId); diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 0974c9139274..70c76d594815 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -17,8 +17,10 @@ package org.apache.spark.shuffle +import scala.compat.java8.OptionConverters + import org.apache.spark._ -import org.apache.spark.internal.{config, Logging} +import org.apache.spark.internal.{Logging, config} import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.api.ShuffleReadSupport import org.apache.spark.storage._ @@ -54,10 +56,12 @@ private[spark] class BlockStoreShuffleReader[K, C]( blockIds.map { case blockId@ShuffleBlockId(_, _, reduceId) => (blockId, serializerManager.wrapStream(blockId, - reader.fetchPartition(reduceId))) + reader.fetchPartition(reduceId, OptionConverters.toJava( + mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId))))) case dataBlockId@ShuffleDataBlockId(_, _, reduceId) => (dataBlockId, serializerManager.wrapStream(dataBlockId, - reader.fetchPartition(reduceId))) + reader.fetchPartition(reduceId, OptionConverters.toJava( + mapOutputTracker.getShuffleLocation(handle.shuffleId, mapId, reduceId))))) case invalid => throw new IllegalArgumentException(s"Invalid block id $invalid") } diff --git a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala index 097d1e406dc0..579fc9a45ba9 100644 --- a/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala +++ b/core/src/test/scala/org/apache/spark/SplitFilesShuffleIO.scala @@ -33,7 +33,7 @@ class SplitFilesShuffleIO(conf: SparkConf) extends ShuffleDataIO { override def initialize(): Unit = {} override def readSupport(): ShuffleReadSupport = (appId: String, shuffleId: Int, mapId: Int) => { - reduceId: Int => { + (reduceId: Int, shuffleLocation: Optional[ShuffleLocation]) => { new FileInputStream(resolvePartitionFile(appId, shuffleId, mapId, reduceId)) } } From af5897833dd4fd760a7c907b1c31ce842b658954 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 22 Jan 2019 14:38:59 -0800 Subject: [PATCH 28/30] Changes to ByteBuffer and serialization logic - add exceptions to API - fix write side: use serializer manager, fix usage of ByteBuffer. - misc compilation / style fixes to get things to build. --- .../spark/shuffle/api/ShuffleDataIO.java | 8 +++-- .../shuffle/api/ShuffleMapOutputWriter.java | 8 +++-- .../shuffle/api/ShufflePartitionReader.java | 8 +++-- .../shuffle/api/ShufflePartitionWriter.java | 7 +++-- .../spark/shuffle/api/ShuffleReadSupport.java | 5 +++- .../shuffle/api/ShuffleWriteSupport.java | 5 +++- .../external/ExternalShuffleReadSupport.java | 3 +- .../shuffle/BlockStoreShuffleReader.scala | 2 +- .../ShufflePartitionWriterOutputStream.scala | 30 +++++++++---------- .../shuffle/sort/SortShuffleWriter.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 5 ++-- .../ShufflePartitionObjectWriter.scala | 11 +++---- .../util/collection/ExternalSorter.scala | 13 +++++--- ...ernetesShuffleServiceAddressProvider.scala | 6 ---- .../cluster/mesos/MesosClusterManager.scala | 3 +- .../cluster/YarnClusterManager.scala | 4 ++- 16 files changed, 68 insertions(+), 52 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java index b091e231f2cd..19cd94712a8a 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -16,11 +16,13 @@ */ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleDataIO { - void initialize(); + void initialize() throws IOException; - ShuffleReadSupport readSupport(); + ShuffleReadSupport readSupport() throws IOException; - ShuffleWriteSupport writeSupport(); + ShuffleWriteSupport writeSupport() throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java index 06415dba72d3..becb9413a8f4 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -17,11 +17,13 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleMapOutputWriter { - ShufflePartitionWriter newPartitionWriter(int partitionId); + ShufflePartitionWriter newPartitionWriter(int partitionId) throws IOException; - void commitAllPartitions(); + void commitAllPartitions() throws IOException; - void abort(Exception exception); + void abort(Exception exception) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java index 817d213cd8cc..46d169972498 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionReader.java @@ -17,12 +17,14 @@ package org.apache.spark.shuffle.api; -import org.apache.spark.storage.ShuffleLocation; - import java.io.InputStream; +import java.io.IOException; import java.util.Optional; +import org.apache.spark.storage.ShuffleLocation; + public interface ShufflePartitionReader { - InputStream fetchPartition(int reduceId, Optional shuffleLocation); + InputStream fetchPartition(int reduceId, Optional shuffleLocation) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java index bdc0fd45474f..e7cc6dd913d1 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -17,6 +17,7 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; import java.io.OutputStream; /** @@ -27,18 +28,18 @@ public interface ShufflePartitionWriter { /** * Return a stream that should persist the bytes for this partition. */ - OutputStream openPartitionStream(); + OutputStream openPartitionStream() throws IOException; /** * Indicate that the partition was written successfully and there are no more incoming bytes. * Returns a {@link CommittedPartition} indicating information about that written partition. */ - CommittedPartition commitPartition(); + CommittedPartition commitPartition() throws IOException; /** * Indicate that the write has failed for some reason and the implementation can handle the * failure reason. After this method is called, this writer will be discarded; it's expected that * the implementation will close any underlying resources. */ - void abort(Exception failureReason); + void abort(Exception failureReason) throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java index b1be7c1de98a..ebe8fd12dccd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleReadSupport.java @@ -17,8 +17,11 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleReadSupport { - ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId); + ShufflePartitionReader newPartitionReader(String appId, int shuffleId, int mapId) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java index 2f61dbaa17c6..f88555f8a1bd 100644 --- a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleWriteSupport.java @@ -17,7 +17,10 @@ package org.apache.spark.shuffle.api; +import java.io.IOException; + public interface ShuffleWriteSupport { - ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId); + ShuffleMapOutputWriter newMapOutputWriter(String appId, int shuffleId, int mapId) + throws IOException; } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index a671b80904ed..a9eac0443fdc 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -1,5 +1,7 @@ package org.apache.spark.shuffle.external; +import scala.compat.java8.OptionConverters; + import com.google.common.collect.Lists; import org.apache.spark.MapOutputTracker; import org.apache.spark.network.TransportContext; @@ -13,7 +15,6 @@ import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.compat.java8.OptionConverters; import java.util.List; import java.util.Optional; diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 70c76d594815..caeecedc5d36 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -20,7 +20,7 @@ package org.apache.spark.shuffle import scala.compat.java8.OptionConverters import org.apache.spark._ -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.api.ShuffleReadSupport import org.apache.spark.storage._ diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala index 2eed51962181..8a776281041d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShufflePartitionWriterOutputStream.scala @@ -20,38 +20,36 @@ package org.apache.spark.shuffle import java.io.{InputStream, OutputStream} import java.nio.ByteBuffer -import org.apache.spark.network.util.LimitedInputStream +import org.apache.spark.serializer.SerializerManager import org.apache.spark.shuffle.api.ShufflePartitionWriter +import org.apache.spark.storage.ShuffleBlockId import org.apache.spark.util.{ByteBufferInputStream, Utils} class ShufflePartitionWriterOutputStream( - partitionWriter: ShufflePartitionWriter, buffer: ByteBuffer, bufferSize: Int) - extends OutputStream { + blockId: ShuffleBlockId, + partitionWriter: ShufflePartitionWriter, + buffer: ByteBuffer, + serializerManager: SerializerManager) + extends OutputStream { - private var currentChunkSize = 0 - private val bufferForRead = buffer.asReadOnlyBuffer() private var underlyingOutputStream: OutputStream = _ override def write(b: Int): Unit = { - buffer.putInt(b) - currentChunkSize += 1 - if (currentChunkSize == bufferSize) { + buffer.put(b.asInstanceOf[Byte]) + if (buffer.remaining() == 0) { pushBufferedBytesToUnderlyingOutput() } } private def pushBufferedBytesToUnderlyingOutput(): Unit = { - bufferForRead.reset() - var bufferInputStream: InputStream = new ByteBufferInputStream(bufferForRead) - if (currentChunkSize < bufferSize) { - bufferInputStream = new LimitedInputStream(bufferInputStream, currentChunkSize) - } + buffer.flip() + var bufferInputStream: InputStream = new ByteBufferInputStream(buffer) if (underlyingOutputStream == null) { - underlyingOutputStream = partitionWriter.openPartitionStream() + underlyingOutputStream = serializerManager.wrapStream(blockId, + partitionWriter.openPartitionStream()) } Utils.copyStream(bufferInputStream, underlyingOutputStream, false, false) - buffer.reset() - currentChunkSize = 0 + buffer.clear() } override def flush(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 98388d80cbe5..b6ab2f354e81 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -71,7 +71,7 @@ private[spark] class SortShuffleWriter[K, V, C]( try { val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID) val committedPartitions = pluggableWriteSupport.map { writeSupport => - sorter.writePartitionedToExternalShuffleWriteSupport(mapId, dep.shuffleId, writeSupport) + sorter.writePartitionedToExternalShuffleWriteSupport(blockId, writeSupport) }.getOrElse(sorter.writePartitionedFile(blockId, tmp)) if (pluggableWriteSupport.isEmpty) { shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 1575b076d3fa..fb1ed02c857a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import java.io._ -import java.lang.ref.{WeakReference, ReferenceQueue => JReferenceQueue} +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels import java.util.Collections @@ -31,11 +31,12 @@ import scala.concurrent.duration._ import scala.reflect.ClassTag import scala.util.Random import scala.util.control.NonFatal + import com.codahale.metrics.{MetricRegistry, MetricSet} import org.apache.spark._ import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.{Logging, config} +import org.apache.spark.internal.{config, Logging} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.metrics.source.Source import org.apache.spark.network._ diff --git a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala index baaee46f8123..b2263a51051a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShufflePartitionObjectWriter.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.nio.ByteBuffer -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.ShufflePartitionWriterOutputStream import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, ShufflePartitionWriter} @@ -30,10 +30,12 @@ import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleMapOutputWriter, * left to the implementation of the underlying implementation of the writer plugin. */ private[spark] class ShufflePartitionObjectWriter( + blockId: ShuffleBlockId, bufferSize: Int, serializerInstance: SerializerInstance, + serializerManager: SerializerManager, mapOutputWriter: ShuffleMapOutputWriter) - extends PairsWriter { + extends PairsWriter { // Reused buffer. Experiments should be done with off-heap at some point. private val buffer = ByteBuffer.allocate(bufferSize) @@ -44,10 +46,9 @@ private[spark] class ShufflePartitionObjectWriter( def startNewPartition(partitionId: Int): Unit = { require(buffer.position() == 0, "Buffer was not flushed to the underlying output on the previous partition.") - buffer.reset() currentWriter = mapOutputWriter.newPartitionWriter(partitionId) val currentWriterStream = new ShufflePartitionWriterOutputStream( - currentWriter, buffer, bufferSize) + blockId, currentWriter, buffer, serializerManager) objectOutputStream = serializerInstance.serializeStream(currentWriterStream) } @@ -56,7 +57,7 @@ private[spark] class ShufflePartitionObjectWriter( require(currentWriter != null, "Cannot commit a partition that has not been started.") objectOutputStream.close() val committedPartition = currentWriter.commitPartition() - buffer.reset() + buffer.clear() currentWriter = null objectOutputStream = null committedPartition diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 69077c644dc7..70c36c40865b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -24,13 +24,13 @@ import com.google.common.io.ByteStreams import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{util, _} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging import org.apache.spark.serializer._ import org.apache.spark.shuffle.api.{CommittedPartition, ShuffleWriteSupport} import org.apache.spark.shuffle.sort.LocalCommittedPartition -import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter, PairsWriter, ShuffleLocation, ShufflePartitionObjectWriter} -import org.apache.spark.{util, _} +import org.apache.spark.storage._ /** * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner @@ -727,14 +727,18 @@ private[spark] class ExternalSorter[K, V, C]( * Write all partitions to some backend that is pluggable. */ def writePartitionedToExternalShuffleWriteSupport( - mapId: Int, shuffleId: Int, writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = { + blockId: ShuffleBlockId, + writeSupport: ShuffleWriteSupport): Array[CommittedPartition] = { // Track location of each range in the output file val committedPartitions = new Array[CommittedPartition](numPartitions) - val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, shuffleId, mapId) + val mapOutputWriter = writeSupport.newMapOutputWriter(conf.getAppId, blockId.shuffleId, + blockId.mapId) val writer = new ShufflePartitionObjectWriter( + blockId, Math.min(serializerBatchSize, Integer.MAX_VALUE).toInt, serInstance, + serializerManager, mapOutputWriter) try { @@ -781,6 +785,7 @@ private[spark] class ExternalSorter[K, V, C]( mapOutputWriter.commitAllPartitions() } catch { case e: Exception => + logError("Error writing shuffle data.", e) util.Utils.tryLogNonFatalError { writer.abortCurrentPartition(e) mapOutputWriter.abort(e) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala index 63074f6f14d7..420a82bd7d8e 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -139,10 +139,4 @@ class KubernetesShuffleServiceAddressProvider( override def onClose(e: KubernetesClientException): Unit = {} } - - private implicit def toRunnable(func: () => Unit): Runnable = { - new Runnable { - override def run(): Unit = func() - } - } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala index 48ef8df37ecc..a69b0d305035 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterManager.scala @@ -60,8 +60,9 @@ private[spark] class MesosClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) + } - override def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = { DefaultShuffleServiceAddressProvider } } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala index b2a4fd42c60f..8e83d49d2332 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterManager.scala @@ -54,6 +54,8 @@ private[spark] class YarnClusterManager extends ExternalClusterManager { override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend) } - def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = + + def createShuffleServiceAddressProvider(): ShuffleServiceAddressProvider = { DefaultShuffleServiceAddressProvider + } } From 1381f555f2e04232bb218ed290d7ec8d6fef40cf Mon Sep 17 00:00:00 2001 From: Ilan Filonenko Date: Tue, 22 Jan 2019 17:49:49 -0800 Subject: [PATCH 29/30] resolve some comments regarding BlockManager and slight style --- .../shuffle/ExternalShuffleBlockHandler.java | 6 -- .../shuffle/ExternalShuffleBlockResolver.java | 6 +- .../shuffle/FileWriterStreamCallback.java | 49 +++++++------- .../external/ExternalShuffleDataIO.java | 66 ++++++++++++++----- .../external/ExternalShuffleLocation.java | 2 - .../ExternalShufflePartitionReader.java | 7 +- .../ExternalShufflePartitionWriter.java | 13 ++-- .../external/ExternalShuffleReadSupport.java | 5 -- .../sort/BypassMergeSortShuffleWriter.java | 4 +- .../shuffle/sort/UnsafeShuffleWriter.java | 11 ++-- .../org/apache/spark/MapOutputTracker.scala | 15 +++-- .../org/apache/spark/executor/Executor.scala | 2 +- .../apache/spark/storage/BlockManager.scala | 48 ++------------ .../BypassMergeSortShuffleWriterSuite.scala | 3 +- .../KubernetesExternalShuffleService.scala | 9 +-- ...ernetesShuffleServiceAddressProvider.scala | 1 + 16 files changed, 119 insertions(+), 128 deletions(-) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 59e2229f2db1..75172957746a 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -198,10 +198,6 @@ private class ShuffleMetrics implements MetricSet { private final Timer registerExecutorRequestLatencyMillis = new Timer(); // Block transfer rate in byte per second private final Meter blockTransferRateBytes = new Meter(); - // Partition upload latency in ms - private final Timer uploadPartitionkStreamMillis = new Timer(); - // Partition read latency in ms - private final Timer openPartitionMillis = new Timer(); private ShuffleMetrics() { allMetrics = new HashMap<>(); @@ -210,8 +206,6 @@ private ShuffleMetrics() { allMetrics.put("blockTransferRateBytes", blockTransferRateBytes); allMetrics.put("registeredExecutorsSize", (Gauge) () -> blockManager.getRegisteredExecutorsSize()); - allMetrics.put("uploadPartitionkStreamMillis", uploadPartitionkStreamMillis); - allMetrics.put("openPartitionMillis", openPartitionMillis); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 757b8f7b545b..7b9c75cd0779 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -19,7 +19,6 @@ import java.io.*; import java.nio.charset.StandardCharsets; -import java.nio.file.Files; import java.util.regex.Pattern; import java.util.*; import java.util.concurrent.ConcurrentMap; @@ -94,7 +93,7 @@ public class ExternalShuffleBlockResolver { private final List knownManagers = Arrays.asList( "org.apache.spark.shuffle.sort.SortShuffleManager", - "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); + "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager"); public ExternalShuffleBlockResolver(TransportConf conf, File registeredExecutorFile) throws IOException { @@ -132,7 +131,6 @@ public int weigh(File file, ShuffleIndexInformation indexInfo) { } else { executors = Maps.newConcurrentMap(); } - this.directoryCleaner = directoryCleaner; } @@ -181,8 +179,6 @@ public ManagedBuffer getBlockData( return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); } - - /** * Removes our metadata of all executors registered for the given application, and optionally * also deletes the local directories associated with the executors of that application in a diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java index 1f44ae8b3c78..9b3736a9ecf1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/FileWriterStreamCallback.java @@ -4,12 +4,11 @@ import org.slf4j.LoggerFactory; import java.io.File; -import java.io.FileNotFoundException; -import java.io.FileOutputStream; import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.channels.Channels; +import java.nio.channels.FileChannel; import java.nio.channels.WritableByteChannel; +import java.nio.file.StandardOpenOption; import org.apache.spark.network.client.StreamCallbackWithID; @@ -23,7 +22,7 @@ public enum FileType { private final String typeString; - FileType(String typeString) { + FileType(String typeString) { this.typeString = typeString; } @@ -55,13 +54,13 @@ public FileWriterStreamCallback( public void open() { logger.info( - "Opening {} for remote writing. File type: {}", file.getAbsolutePath(), fileType); + "Opening {} for remote writing. File type: {}", file.getAbsolutePath(), fileType); if (fileOutputChannel != null) { throw new IllegalStateException( - String.format( - "File %s for is already open for writing (type: %s).", - file.getAbsolutePath(), - fileType)); + String.format( + "File %s for is already open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); } if (!file.exists()) { try { @@ -88,8 +87,8 @@ public void open() { } try { // TODO encryption - fileOutputChannel = Channels.newChannel(new FileOutputStream(file)); - } catch (FileNotFoundException e) { + fileOutputChannel = FileChannel.open(file.toPath(), StandardOpenOption.APPEND); + } catch (IOException e) { throw new RuntimeException( String.format( "Failed to find file for writing at %s (type: %s).", @@ -102,10 +101,10 @@ public void open() { @Override public String getID() { return String.format("%s-%d-%d-%s", - appId, - shuffleId, - mapId, - fileType); + appId, + shuffleId, + mapId, + fileType); } @Override @@ -119,23 +118,23 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { @Override public void onComplete(String streamId) throws IOException { logger.info( - "Finished writing {}. File type: {}", file.getAbsolutePath(), fileType); + "Finished writing {}. File type: {}", file.getAbsolutePath(), fileType); fileOutputChannel.close(); } @Override public void onFailure(String streamId, Throwable cause) throws IOException { logger.warn("Failed to write shuffle file at {} (type: %s).", - file.getAbsolutePath(), - fileType, - cause); + file.getAbsolutePath(), + fileType, + cause); fileOutputChannel.close(); // TODO delete parent dirs too if (!file.delete()) { logger.warn( - "Failed to delete incomplete remote shuffle file at %s (type: %s)", - file.getAbsolutePath(), - fileType); + "Failed to delete incomplete remote shuffle file at %s (type: %s)", + file.getAbsolutePath(), + fileType); } } @@ -143,9 +142,9 @@ private void verifyShuffleFileOpenForWriting() { if (fileOutputChannel == null) { throw new RuntimeException( String.format( - "Shuffle file at %s not open for writing (type: %s).", - file.getAbsolutePath(), - fileType)); + "Shuffle file at %s not open for writing (type: %s).", + file.getAbsolutePath(), + fileType)); } } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java index 2a0a39e4b82e..10e4093759c3 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleDataIO.java @@ -1,8 +1,8 @@ package org.apache.spark.shuffle.external; -import org.apache.spark.MapOutputTracker; -import org.apache.spark.SparkConf; -import org.apache.spark.SparkEnv; +import org.apache.spark.*; +import org.apache.spark.SecurityManager; +import org.apache.spark.internal.config.package$; import org.apache.spark.network.TransportContext; import org.apache.spark.network.netty.SparkTransportConf; import org.apache.spark.network.server.NoOpRpcHandler; @@ -10,45 +10,77 @@ import org.apache.spark.shuffle.api.ShuffleDataIO; import org.apache.spark.shuffle.api.ShuffleReadSupport; import org.apache.spark.shuffle.api.ShuffleWriteSupport; -import org.apache.spark.SecurityManager; +import org.apache.spark.network.shuffle.k8s.KubernetesExternalShuffleClient; + import org.apache.spark.storage.BlockManager; +import scala.Tuple2; + +import java.util.List; +import java.util.Random; public class ExternalShuffleDataIO implements ShuffleDataIO { - private final TransportConf conf; + private final SparkConf conf; + private final TransportConf transportConf; private final TransportContext context; - private static BlockManager blockManager; + private static MapOutputTracker mapOutputTracker; private static SecurityManager securityManager; - private static String hostname; - private static int port; + private static List> hostPorts; + private static Boolean isDriver; + private static KubernetesExternalShuffleClient shuffleClient; public ExternalShuffleDataIO( SparkConf sparkConf) { - this.conf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1); + this.conf = sparkConf; + // TODO: Grab numUsableCores + this.transportConf = SparkTransportConf.fromSparkConf(sparkConf, "shuffle", 1); // Close idle connections - this.context = new TransportContext(conf, new NoOpRpcHandler(), true, true); + this.context = new TransportContext(transportConf, new NoOpRpcHandler(), true, true); } @Override public void initialize() { SparkEnv env = SparkEnv.get(); - blockManager = env.blockManager(); + mapOutputTracker = env.mapOutputTracker(); securityManager = env.securityManager(); - hostname = blockManager.getRandomShuffleHost(); - port = blockManager.getRandomShufflePort(); - // TODO: Register Driver and Executor + isDriver = env.blockManager().blockManagerId().isDriver(); + hostPorts = mapOutputTracker.getRemoteShuffleServiceAddress(); + if (isDriver) { + shuffleClient = new KubernetesExternalShuffleClient(transportConf, securityManager, + securityManager.isAuthenticationEnabled(), conf.getTimeAsMs( + package$.MODULE$.SHUFFLE_REGISTRATION_TIMEOUT().key(), "5000ms")); + shuffleClient.init(conf.getAppId()); + for (Tuple2 hp : hostPorts) { + try { + shuffleClient.registerDriverWithShuffleService( + hp._1, hp._2, + conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", + conf.getTimeAsSeconds("spark.network.timeout", "120s") + "s"), + conf.getTimeAsSeconds( + package$.MODULE$.EXECUTOR_HEARTBEAT_INTERVAL().key(), "10s")); + } catch (Exception e) { + throw new RuntimeException("Unable to register driver with ESS", e); + } + } + BlockManager.ShuffleMetricsSource metricSource = + new BlockManager.ShuffleMetricsSource( + "RemoteShuffleService", shuffleClient.shuffleMetrics()); + env.metricsSystem().registerSource(metricSource); + } } @Override public ShuffleReadSupport readSupport() { return new ExternalShuffleReadSupport( - conf, context, securityManager.isAuthenticationEnabled(), securityManager); + transportConf, context, securityManager.isAuthenticationEnabled(), securityManager); } @Override public ShuffleWriteSupport writeSupport() { + int rnd = new Random().nextInt(hostPorts.size()); + Tuple2 hostPort = hostPorts.get(rnd); return new ExternalShuffleWriteSupport( - conf, context, securityManager.isAuthenticationEnabled(), - securityManager, hostname, port); + transportConf, context, securityManager.isAuthenticationEnabled(), + securityManager, hostPort._1, hostPort._2); } } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java index 20ae8d376050..c1178da2411f 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleLocation.java @@ -1,7 +1,5 @@ package org.apache.spark.shuffle.external; -import org.apache.hadoop.mapreduce.task.reduce.Shuffle; -import org.apache.spark.network.protocol.Encoders; import org.apache.spark.storage.ShuffleLocation; import java.io.*; diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java index f02783263925..10f1b7100847 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionReader.java @@ -8,7 +8,6 @@ import org.apache.spark.util.ByteBufferInputStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import scala.compat.java8.OptionConverters; import java.io.ByteArrayInputStream; import java.io.InputStream; @@ -39,8 +38,10 @@ public ExternalShufflePartitionReader( @Override public InputStream fetchPartition(int reduceId, Optional shuffleLocation) { - assert shuffleLocation.isPresent() && shuffleLocation.get() instanceof ExternalShuffleLocation; - ExternalShuffleLocation externalShuffleLocation = (ExternalShuffleLocation) shuffleLocation.get(); + assert shuffleLocation.isPresent() && + shuffleLocation.get() instanceof ExternalShuffleLocation; + ExternalShuffleLocation externalShuffleLocation = + (ExternalShuffleLocation) shuffleLocation.get(); logger.info(String.format("Found external shuffle location on node: %s:%d", externalShuffleLocation.getShuffleHostname(), externalShuffleLocation.getShufflePort())); diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java index 89bfe4407e5a..edf046a32ffe 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShufflePartitionWriter.java @@ -8,14 +8,12 @@ import org.apache.spark.network.shuffle.protocol.UploadShufflePartitionStream; import org.apache.spark.shuffle.api.CommittedPartition; import org.apache.spark.shuffle.api.ShufflePartitionWriter; -import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.*; import java.nio.ByteBuffer; import java.util.Arrays; -import java.util.Optional; public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { @@ -31,7 +29,7 @@ public class ExternalShufflePartitionWriter implements ShufflePartitionWriter { private final int partitionId; private long totalLength = 0; - private final ByteArrayOutputStream partitionBuffer = new ByteArrayOutputStream(); + private ByteArrayOutputStream partitionBuffer; public ExternalShufflePartitionWriter( TransportClientFactory clientFactory, @@ -48,6 +46,8 @@ public ExternalShufflePartitionWriter( this.shuffleId = shuffleId; this.mapId = mapId; this.partitionId = partitionId; + // TODO: Set buffer size + this.partitionBuffer = new ByteArrayOutputStream(); } @Override @@ -84,6 +84,7 @@ public void onFailure(Throwable e) { logger.info("Size: " + size); } catch (Exception e) { if (client != null) { + partitionBuffer = null; client.close(); } logger.error("Encountered error while attempting to upload partition to ESS", e); @@ -91,13 +92,15 @@ public void onFailure(Throwable e) { } finally { logger.info("Successfully sent partition to ESS"); } - return new ExternalCommittedPartition(totalLength, new ExternalShuffleLocation(hostName, port)); + return new ExternalCommittedPartition( + totalLength, new ExternalShuffleLocation(hostName, port)); } @Override public void abort(Exception failureReason) { try { - this.partitionBuffer.close(); + partitionBuffer.close(); + partitionBuffer = null; } catch(IOException e) { logger.error("Failed to close streams after failing to upload partition", e); } diff --git a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java index a9eac0443fdc..0bde10a77766 100644 --- a/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java +++ b/core/src/main/java/org/apache/spark/shuffle/external/ExternalShuffleReadSupport.java @@ -1,9 +1,6 @@ package org.apache.spark.shuffle.external; -import scala.compat.java8.OptionConverters; - import com.google.common.collect.Lists; -import org.apache.spark.MapOutputTracker; import org.apache.spark.network.TransportContext; import org.apache.spark.network.client.TransportClientBootstrap; import org.apache.spark.network.client.TransportClientFactory; @@ -12,12 +9,10 @@ import org.apache.spark.network.util.TransportConf; import org.apache.spark.shuffle.api.ShufflePartitionReader; import org.apache.spark.shuffle.api.ShuffleReadSupport; -import org.apache.spark.storage.ShuffleLocation; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.List; -import java.util.Optional; public class ExternalShuffleReadSupport implements ShuffleReadSupport { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 26b55aa70387..b21d37401c05 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -46,7 +46,6 @@ import javax.annotation.Nullable; import java.io.*; import java.util.Arrays; -import java.util.stream.Collectors; /** * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path @@ -171,6 +170,7 @@ public void write(Iterator> records) throws IOException { try { committedPartitions = combineAndWritePartitions(tmp); logger.info("Successfully wrote partitions without shuffle"); + // TODO: Investigate when commitedPartitions is null or returns empty shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, Arrays.stream(committedPartitions).mapToLong(p -> p.length()).toArray(), @@ -213,7 +213,7 @@ private CommittedPartition[] combineAndWritePartitions(File outputFile) throws I boolean copyThrewException = true; try { partitions[i] = - new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled)); + new LocalCommittedPartition(Utils.copyStream(in, out, false, transferToEnabled)); copyThrewException = false; } finally { Closeables.close(in, copyThrewException); diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index ef086e21b04d..9eddfc924404 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -23,10 +23,8 @@ import java.util.Arrays; import java.util.Iterator; import java.util.stream.Collectors; -import java.util.stream.Stream; import org.apache.spark.shuffle.api.CommittedPartition; -import org.apache.spark.storage.ShuffleLocation; import scala.Option; import scala.Product2; import scala.collection.JavaConverters; @@ -339,14 +337,17 @@ private CommittedPartition[] mergeSpills(SpillInfo[] spills, File outputFile) th // that doesn't need to interpret the spilled bytes. if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); - committedPartitions = toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile)); + committedPartitions = + toLocalCommittedPartition(mergeSpillsWithTransferTo(spills, outputFile)); } else { logger.debug("Using fileStream-based fast merge"); - committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, null)); + committedPartitions = toLocalCommittedPartition( + mergeSpillsWithFileStream(spills, outputFile, null)); } } else { logger.debug("Using slow merge"); - committedPartitions = toLocalCommittedPartition(mergeSpillsWithFileStream(spills, outputFile, compressionCodec)); + committedPartitions = toLocalCommittedPartition( + mergeSpillsWithFileStream(spills, outputFile, compressionCodec)); } // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has // in-memory records, we write out the in-memory records to a file but do not count that diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index b340ffbbc43d..214ff3ee18fe 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -22,7 +22,6 @@ import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolE import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.JavaConverters._ -import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet, ListBuffer, Map} import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration @@ -237,7 +236,7 @@ private[spark] class MapOutputTrackerMasterEndpoint( stop() case GetRemoteShuffleServiceAddresses => - context.reply(tracker.getRemoteShuffleServiceAddresses) + context.reply(tracker.getRemoteShuffleServiceAddress()) } } @@ -305,6 +304,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging def getShuffleLocation(shuffleId: Int, mapId: Int, reduceId: Int) : Option[ShuffleLocation] + def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -653,8 +654,9 @@ private[spark] class MapOutputTrackerMaster( } } - def getRemoteShuffleServiceAddresses: List[(String, Int)] = - shuffleServiceAddressProvider.getShuffleServiceAddresses() + override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] = + shuffleServiceAddressProvider + .getShuffleServiceAddresses().map { case (h, p) => (h, new Integer(p))}.asJava // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result. // This method is only called in local-mode. @@ -807,6 +809,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr case None => Option.empty } } + + override def getRemoteShuffleServiceAddress(): java.util.List[(String, Integer)] = { + trackerEndpoint + .askSync[java.util.List[(String, Integer)]](GetRemoteShuffleServiceAddresses) + } } private[spark] object MapOutputTracker extends Logging { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index ae5b1a3c6946..b6f4fc1921bf 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -118,7 +118,7 @@ private[spark] class Executor( env.blockManager.initialize(conf.getAppId) env.metricsSystem.registerSource(executorSource) env.metricsSystem.registerSource(env.blockManager.shuffleMetricsSource) - // Initialize the ShuffleDataIo + // Initialize the ShuffleDataIO env.shuffleDataIO.foreach(_.initialize()) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index fb1ed02c857a..f604f20ee722 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -44,7 +44,6 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle._ -import org.apache.spark.network.shuffle.k8s.KubernetesExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.network.util.TransportConf import org.apache.spark.rpc.RpcEnv @@ -131,14 +130,8 @@ private[spark] class BlockManager( numUsableCores: Int) extends BlockDataManager with BlockEvictionHandler with Logging { - private[spark] val externalNonK8sShuffleService = - conf.get(config.SHUFFLE_SERVICE_ENABLED) - - private[spark] val externalk8sShuffleServiceEnabled = - conf.get(config.K8S_SHUFFLE_SERVICE_ENABLED) - private[spark] val externalShuffleServiceEnabled = - externalNonK8sShuffleService || externalk8sShuffleServiceEnabled + conf.get(config.SHUFFLE_SERVICE_ENABLED) private val remoteReadNioBufferConversion = conf.getBoolean("spark.network.remoteReadNioBufferConversion", false) @@ -184,9 +177,6 @@ private[spark] class BlockManager( } } - private var remoteShuffleServiceAddress: List[(String, Int)] = List() - private var randomShuffleServiceAddress: (String, Int) = null - var blockManagerId: BlockManagerId = _ // Address of the server that serves this executor's shuffle files. This is either an external @@ -195,11 +185,7 @@ private[spark] class BlockManager( // Client to read other executors' shuffle files. This is either an external service, or just the // standard BlockTransferService to directly connect to other Executors. - private[spark] val shuffleClient = if (externalk8sShuffleServiceEnabled) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - new KubernetesExternalShuffleClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) - } else if (externalNonK8sShuffleService) { + private[spark] val shuffleClient = if (externalShuffleServiceEnabled) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)) @@ -267,35 +253,14 @@ private[spark] class BlockManager( blockManagerId = if (idFromMaster != null) idFromMaster else id - if (externalk8sShuffleServiceEnabled) { - remoteShuffleServiceAddress = mapOutputTracker - .trackerEndpoint - .askSync[List[(String, Int)]](GetRemoteShuffleServiceAddresses) - } - - shuffleServerId = if (externalk8sShuffleServiceEnabled) { - // TODO: Investigate better methods of load balancing - // note: might break if re-initialized - randomShuffleServiceAddress = remoteShuffleServiceAddress.head - BlockManagerId(executorId, randomShuffleServiceAddress._1, randomShuffleServiceAddress._2) - } else if (externalNonK8sShuffleService) { + shuffleServerId = if (externalShuffleServiceEnabled) { logInfo(s"external shuffle service port = $externalShuffleServicePort") BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort) } else { blockManagerId } - if (externalk8sShuffleServiceEnabled && blockManagerId.isDriver) { - // Register Drivers' configuration with the k8s shuffle services - remoteShuffleServiceAddress.foreach { ssId => - shuffleClient.asInstanceOf[KubernetesExternalShuffleClient] - .registerDriverWithShuffleService( - ssId._1, ssId._2, - conf.getTimeAsMs("spark.storage.blockManagerSlaveTimeoutMs", - s"${conf.getTimeAsSeconds("spark.network.timeout", "120s")}s"), - conf.get(config.EXECUTOR_HEARTBEAT_INTERVAL)) - } - } else if (externalNonK8sShuffleService && !blockManagerId.isDriver) { + if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { // Register Executors' configuration with the local shuffle service, if one should exist. registerWithExternalShuffleServer() } @@ -362,9 +327,6 @@ private[spark] class BlockManager( } } - private[spark] def getRandomShuffleHost: String = randomShuffleServiceAddress._1 - private[spark] def getRandomShufflePort: Int = randomShuffleServiceAddress._2 - /** * Re-register with the master and report all blocks to it. This will be called by the heart beat * thread if our heartbeat to the block manager indicates that we were not registered. @@ -1702,7 +1664,7 @@ private[spark] object BlockManager { blockManagers.toMap } - private class ShuffleMetricsSource( + class ShuffleMetricsSource( override val sourceName: String, metricSet: MetricSet) extends Source { diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 75202b57833a..8919692a8f76 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -49,7 +49,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ - private val conf: SparkConf = new SparkConf(loadDefaults = false).set("spark.app.id", "spark-app-id") + private val conf: SparkConf = + new SparkConf(loadDefaults = false).set("spark.app.id", "spark-app-id") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala index b9d69f1bc69f..07dbffacc31f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesExternalShuffleService.scala @@ -282,6 +282,7 @@ private[spark] class KubernetesExternalShuffleBlockHandler( allMetrics.put("openBlockRequestLatencyMillis", _openBlockRequestLatencyMillis) allMetrics.put("registerDriverRequestLatencyMillis", _registerDriverRequestLatencyMillis) allMetrics.put("blockTransferRateBytes", _blockTransferRateBytes) + override def getMetrics: util.Map[String, Metric] = allMetrics } @@ -297,10 +298,10 @@ private[spark] class KubernetesExternalShuffleService( extends ExternalShuffleService(conf, securityManager) { protected override def newShuffleBlockHandler( - conf: TransportConf): ExternalShuffleBlockHandler = { - val cleanerIntervals = this.conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) - val indexCacheSize = this.conf.get("spark.shuffle.service.index.cache.size", "100m") - new KubernetesExternalShuffleBlockHandler(conf, cleanerIntervals, indexCacheSize) + transportConf: TransportConf): ExternalShuffleBlockHandler = { + val cleanerIntervals = conf.get(KUBERNETES_REMOTE_SHUFFLE_SERVICE_CLEANUP_INTERVAL) + val indexCacheSize = conf.get("spark.shuffle.service.index.cache.size", "100m") + new KubernetesExternalShuffleBlockHandler(transportConf, cleanerIntervals, indexCacheSize) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala index 420a82bd7d8e..f0bdb3521665 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/shuffle/k8s/KubernetesShuffleServiceAddressProvider.scala @@ -92,6 +92,7 @@ class KubernetesShuffleServiceAddressProvider( } } + // TODO: Re-register with found shuffle service instances private def pollForPods(): Unit = { val writeLock = podsUpdateLock.writeLock() writeLock.lock() From 7c0fa1d4e0f668e613528633188141fafc60acc4 Mon Sep 17 00:00:00 2001 From: Yifei Huang Date: Wed, 23 Jan 2019 16:18:14 -0800 Subject: [PATCH 30/30] Fix UnsafeShuffleWriter (#15) * compiles ` * fix UnsafeShuffleWriter * remove unnecessary changes --- .../spark/shuffle/sort/UnsafeShuffleWriter.java | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 9eddfc924404..97f34bf46049 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -548,6 +548,10 @@ private CommittedPartition[] mergeSpillsWithPluggableWriter( ShufflePartitionWriter writer = mapOutputWriter.newPartitionWriter(partition); try { try (OutputStream partitionOutput = writer.openPartitionStream()) { + OutputStream partitionOutputStream = partitionOutput; + if (compressionCodec != null) { + partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput); + } for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { @@ -560,7 +564,7 @@ private CommittedPartition[] mergeSpillsWithPluggableWriter( partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } - Utils.copyStream(partitionInputStream, partitionOutput, false, false); + Utils.copyStream(partitionInputStream, partitionOutputStream, false, false); } finally { partitionInputStream.close(); } @@ -621,7 +625,11 @@ private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter( partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } try (OutputStream partitionOutput = writer.openPartitionStream()) { - Utils.copyStream(partitionInputStream, partitionOutput, false, false); + OutputStream partitionOutputStream = partitionOutput; + if (compressionCodec != null) { + partitionOutputStream = compressionCodec.compressedOutputStream(partitionOutput); + } + Utils.copyStream(partitionInputStream, partitionOutputStream, false, false); } } catch (Exception e) { try { @@ -637,6 +645,7 @@ private CommittedPartition[] writeSingleSpillFileUsingPluggableWriter( writeMetrics.incBytesWritten(committedPartitions[partition].length()); } threwException = false; + mapOutputWriter.commitAllPartitions(); } catch (Exception e) { try { mapOutputWriter.abort(e);