diff --git a/tez-api/src/main/java/org/apache/tez/common/TezUtils.java b/tez-api/src/main/java/org/apache/tez/common/TezUtils.java index 51311ffd80..23811aa7f1 100644 --- a/tez-api/src/main/java/org/apache/tez/common/TezUtils.java +++ b/tez-api/src/main/java/org/apache/tez/common/TezUtils.java @@ -30,6 +30,7 @@ import com.google.protobuf.ByteString; import com.google.protobuf.CodedInputStream; +import org.apache.tez.runtime.api.TaskContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.classification.InterfaceAudience; @@ -120,6 +121,27 @@ public static Configuration createConfFromByteString(ByteString byteString) thro } } + public static Configuration createConfFromBaseConfAndPayload(TaskContext context) + throws IOException { + Configuration baseConf = context.getContainerConfiguration(); + Configuration configuration = new Configuration(baseConf); + UserPayload payload = context.getUserPayload(); + ByteString byteString = ByteString.copyFrom(payload.getPayload()); + try(SnappyInputStream uncompressIs = new SnappyInputStream(byteString.newInput())) { + DAGProtos.ConfigurationProto confProto = DAGProtos.ConfigurationProto.parseFrom(uncompressIs); + readConfFromPB(confProto, configuration); + return configuration; + } + } + + public static void addToConfFromByteString(Configuration configuration, ByteString byteString) + throws IOException { + try(SnappyInputStream uncompressIs = new SnappyInputStream(byteString.newInput())) { + DAGProtos.ConfigurationProto confProto = DAGProtos.ConfigurationProto.parseFrom(uncompressIs); + readConfFromPB(confProto, configuration); + } + } + /** * Convert an instance of {@link org.apache.tez.dag.api.UserPayload} to {@link * org.apache.hadoop.conf.Configuration} diff --git a/tez-api/src/main/java/org/apache/tez/runtime/api/InputInitializerContext.java b/tez-api/src/main/java/org/apache/tez/runtime/api/InputInitializerContext.java index ccfac46e21..7c9562e7c9 100644 --- a/tez-api/src/main/java/org/apache/tez/runtime/api/InputInitializerContext.java +++ b/tez-api/src/main/java/org/apache/tez/runtime/api/InputInitializerContext.java @@ -24,6 +24,7 @@ import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.classification.InterfaceStability.Unstable; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.Resource; import org.apache.tez.common.counters.TezCounters; @@ -49,7 +50,13 @@ public interface InputInitializerContext { * @return DAG name */ String getDAGName(); - + + /** + * Get vertex configuration + * @return Vertex configuration + */ + Configuration getVertexConfiguration(); + /** * Get the name of the input * @return Input name diff --git a/tez-api/src/main/java/org/apache/tez/runtime/api/TaskContext.java b/tez-api/src/main/java/org/apache/tez/runtime/api/TaskContext.java index dd2951a382..1ba1a90e3e 100644 --- a/tez-api/src/main/java/org/apache/tez/runtime/api/TaskContext.java +++ b/tez-api/src/main/java/org/apache/tez/runtime/api/TaskContext.java @@ -27,6 +27,7 @@ import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.classification.InterfaceStability.Unstable; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.tez.common.counters.TezCounters; import org.apache.tez.dag.api.UserPayload; @@ -62,6 +63,12 @@ public interface TaskContext { */ public int getTaskAttemptNumber(); + /** + * Get container configuration + * @return Container configuration + */ + public Configuration getContainerConfiguration(); + /** * Get the name of the DAG * @return the DAG name diff --git a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TezRootInputInitializerContextImpl.java b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TezRootInputInitializerContextImpl.java index 43764878b6..a994359354 100644 --- a/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TezRootInputInitializerContextImpl.java +++ b/tez-dag/src/main/java/org/apache/tez/dag/app/dag/impl/TezRootInputInitializerContextImpl.java @@ -23,6 +23,7 @@ import java.util.Set; import java.util.Objects; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.Resource; import org.apache.tez.common.counters.TezCounters; @@ -85,7 +86,12 @@ public UserPayload getInputUserPayload() { public UserPayload getUserPayload() { return this.input.getControllerDescriptor().getUserPayload(); } - + + @Override + public Configuration getVertexConfiguration() { + return vertex.getConf(); + } + @Override public int getNumTasks() { return vertex.getTotalTasks(); diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java index dbfdcb3843..d06a5f46a0 100644 --- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java +++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/common/MRInputAMSplitGenerator.java @@ -30,7 +30,6 @@ import org.apache.hadoop.classification.InterfaceStability.Evolving; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.mapred.JobConf; -import org.apache.hadoop.mapreduce.split.TezMapReduceSplitsGrouper; import org.apache.hadoop.security.UserGroupInformation; import org.apache.tez.common.TezUtils; import org.apache.tez.dag.api.VertexLocationHint; @@ -80,8 +79,8 @@ public List initialize() throws Exception { + sw.now(TimeUnit.MILLISECONDS)); } sw.reset().start(); - Configuration conf = TezUtils.createConfFromByteString(userPayloadProto - .getConfigurationBytes()); + Configuration conf = new JobConf(getContext().getVertexConfiguration()); + TezUtils.addToConfFromByteString(conf, userPayloadProto.getConfigurationBytes()); sendSerializedEvents = conf.getBoolean( MRJobConfig.MR_TEZ_INPUT_INITIALIZER_SERIALIZE_EVENT_PAYLOAD, diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/base/MRInputBase.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/base/MRInputBase.java index d8c531ea84..ccae0b1964 100644 --- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/base/MRInputBase.java +++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/input/base/MRInputBase.java @@ -72,8 +72,9 @@ public List initialize() throws IOException { boolean isGrouped = mrUserPayload.getGroupingEnabled(); Preconditions.checkArgument(mrUserPayload.hasSplits() == false, "Split information not expected in " + this.getClass().getName()); - Configuration conf = TezUtils - .createConfFromByteString(mrUserPayload.getConfigurationBytes()); + + Configuration conf = new JobConf(getContext().getContainerConfiguration()); + TezUtils.addToConfFromByteString(conf, mrUserPayload.getConfigurationBytes()); this.jobConf = new JobConf(conf); useNewApi = this.jobConf.getUseNewMapper(); if (isGrouped) { diff --git a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/output/MROutput.java b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/output/MROutput.java index 18047ebf09..950e629907 100644 --- a/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/output/MROutput.java +++ b/tez-mapreduce/src/main/java/org/apache/tez/mapreduce/output/MROutput.java @@ -29,6 +29,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.ByteString; import org.apache.tez.common.Preconditions; import org.apache.commons.lang.StringUtils; import org.apache.hadoop.mapreduce.lib.output.LazyOutputFormat; @@ -398,8 +399,9 @@ protected List initializeBase() throws IOException, InterruptedException taskNumberFormat.setGroupingUsed(false); nonTaskNumberFormat.setMinimumIntegerDigits(3); nonTaskNumberFormat.setGroupingUsed(false); - Configuration conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); - this.jobConf = new JobConf(conf); + UserPayload userPayload = getContext().getUserPayload(); + this.jobConf = new JobConf(getContext().getContainerConfiguration()); + TezUtils.addToConfFromByteString(this.jobConf, ByteString.copyFrom(userPayload.getPayload())); // Add tokens to the jobConf - in case they are accessed within the RW / OF jobConf.getCredentials().mergeAll(UserGroupInformation.getCurrentUser().getCredentials()); this.isMapperOutput = jobConf.getBoolean(MRConfig.IS_MAP_PROCESSOR, diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java index 369afbe6b3..83c28dd7bb 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/TezTestUtils.java @@ -17,6 +17,7 @@ */ package org.apache.tez.mapreduce; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.yarn.api.records.ApplicationId; import org.apache.hadoop.yarn.api.records.Resource; import org.apache.tez.common.counters.TezCounters; @@ -59,10 +60,12 @@ public static class TezRootInputInitializerContextForTest implements private final ApplicationId appId; private final UserPayload payload; + private final Configuration vertexConfig; - public TezRootInputInitializerContextForTest(UserPayload payload) throws IOException { + public TezRootInputInitializerContextForTest(UserPayload payload, Configuration vertexConfig) throws IOException { appId = ApplicationId.newInstance(1000, 200); this.payload = payload == null ? UserPayload.create(null) : payload; + this.vertexConfig = vertexConfig; } @Override @@ -75,6 +78,11 @@ public String getDAGName() { return "FakeDAG"; } + @Override + public Configuration getVertexConfiguration() { + return vertexConfig; + } + @Override public String getInputName() { return "MRInput"; diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputAMSplitGenerator.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputAMSplitGenerator.java index 6cf2700564..9f6ac3b74f 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputAMSplitGenerator.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputAMSplitGenerator.java @@ -96,7 +96,7 @@ private void testGroupSplitsAndSortSplits(boolean groupSplitsEnabled, UserPayload userPayload = dataSource.getInputDescriptor().getUserPayload(); InputInitializerContext context = - new TezTestUtils.TezRootInputInitializerContextForTest(userPayload); + new TezTestUtils.TezRootInputInitializerContextForTest(userPayload, new Configuration(false)); MRInputAMSplitGenerator splitGenerator = new MRInputAMSplitGenerator(context); diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java index 3772cde946..4aaa7e2e76 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/common/TestMRInputSplitDistributor.java @@ -70,7 +70,8 @@ public void testSerializedPayload() throws IOException { UserPayload userPayload = UserPayload.create(payloadProto.build().toByteString().asReadOnlyByteBuffer()); - InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload); + InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload, + new Configuration(false)); MRInputSplitDistributor splitDist = new MRInputSplitDistributor(context); List events = splitDist.initialize(); @@ -119,7 +120,8 @@ public void testDeserializedPayload() throws IOException { UserPayload userPayload = UserPayload.create(payloadProto.build().toByteString().asReadOnlyByteBuffer()); - InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload); + InputInitializerContext context = new TezTestUtils.TezRootInputInitializerContextForTest(userPayload, + new Configuration(false)); MRInputSplitDistributor splitDist = new MRInputSplitDistributor(context); List events = splitDist.initialize(); diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MRInputForTest.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MRInputForTest.java new file mode 100644 index 0000000000..0d1d24ff6f --- /dev/null +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MRInputForTest.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.mapreduce.input; + +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.runtime.api.InputContext; + +/** + * This is used for inspecting jobConf in test. + */ +public class MRInputForTest extends MRInput { + public MRInputForTest(InputContext inputContext, int numPhysicalInputs) { + super(inputContext, numPhysicalInputs); + } + + public Configuration getConfiguration() { + return jobConf; + } +} diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MultiMRInputForTest.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MultiMRInputForTest.java new file mode 100644 index 0000000000..f0f0a77aa7 --- /dev/null +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/MultiMRInputForTest.java @@ -0,0 +1,35 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.tez.mapreduce.input; + +import org.apache.hadoop.conf.Configuration; +import org.apache.tez.runtime.api.InputContext; + +/** + * This is used for inspecting jobConf in test. + */ +public class MultiMRInputForTest extends MultiMRInput { + public MultiMRInputForTest(InputContext inputContext, int numPhysicalInputs) { + super(inputContext, numPhysicalInputs); + } + + public Configuration getConfiguration() { + return jobConf; + } +} diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMRInput.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMRInput.java index 9109cd9c47..5ca5c26619 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMRInput.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMRInput.java @@ -69,6 +69,7 @@ public void test0PhysicalInputs() throws IOException { doReturn(1).when(inputContext).getTaskIndex(); doReturn(1).when(inputContext).getTaskAttemptNumber(); doReturn(new TezCounters()).when(inputContext).getCounters(); + doReturn(new JobConf(false)).when(inputContext).getContainerConfiguration(); MRInput mrInput = new MRInput(inputContext, 0); @@ -120,6 +121,7 @@ public void testAttributesInJobConf() throws Exception { doReturn(TEST_ATTRIBUTES_INPUT_NAME).when(inputContext).getSourceVertexName(); doReturn(TEST_ATTRIBUTES_APPLICATION_ID).when(inputContext).getApplicationId(); doReturn(TEST_ATTRIBUTES_UNIQUE_IDENTIFIER).when(inputContext).getUniqueIdentifier(); + doReturn(new Configuration(false)).when(inputContext).getContainerConfiguration(); DataSourceDescriptor dsd = MRInput.createConfigBuilder(new Configuration(false), @@ -147,6 +149,43 @@ public void testAttributesInJobConf() throws Exception { assertTrue(TestInputFormat.invoked.get()); } + @Test(timeout = 5000) + public void testConfigMerge() throws Exception { + JobConf jobConf = new JobConf(false); + jobConf.set("payload-key", "payload-value"); + + Configuration localConfig = new Configuration(false); + localConfig.set("local-key", "local-value"); + + InputContext inputContext = mock(InputContext.class); + + DataSourceDescriptor dsd = MRInput.createConfigBuilder(jobConf, + TestInputFormat.class).groupSplits(false).build(); + + doReturn(dsd.getInputDescriptor().getUserPayload()).when(inputContext).getUserPayload(); + doReturn(TEST_ATTRIBUTES_DAG_INDEX).when(inputContext).getDagIdentifier(); + doReturn(TEST_ATTRIBUTES_VERTEX_INDEX).when(inputContext).getTaskVertexIndex(); + doReturn(TEST_ATTRIBUTES_TASK_INDEX).when(inputContext).getTaskIndex(); + doReturn(TEST_ATTRIBUTES_TASK_ATTEMPT_INDEX).when(inputContext).getTaskAttemptNumber(); + doReturn(TEST_ATTRIBUTES_INPUT_INDEX).when(inputContext).getInputIndex(); + doReturn(TEST_ATTRIBUTES_DAG_ATTEMPT_NUMBER).when(inputContext).getDAGAttemptNumber(); + doReturn(TEST_ATTRIBUTES_DAG_NAME).when(inputContext).getDAGName(); + doReturn(TEST_ATTRIBUTES_VERTEX_NAME).when(inputContext).getTaskVertexName(); + doReturn(TEST_ATTRIBUTES_INPUT_NAME).when(inputContext).getSourceVertexName(); + doReturn(TEST_ATTRIBUTES_APPLICATION_ID).when(inputContext).getApplicationId(); + doReturn(TEST_ATTRIBUTES_UNIQUE_IDENTIFIER).when(inputContext).getUniqueIdentifier(); + doReturn(localConfig).when(inputContext).getContainerConfiguration(); + doReturn(new TezCounters()).when(inputContext).getCounters(); + + MRInputForTest input = new MRInputForTest(inputContext, 1); + input.initialize(); + + Configuration mergedConfig = input.getConfiguration(); + + assertEquals("local-value", mergedConfig.get("local-key")); + assertEquals("payload-value", mergedConfig.get("payload-key")); + } + /** * Test class to verify */ diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMultiMRInput.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMultiMRInput.java index 8d77a0539b..bd6e891bd2 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMultiMRInput.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/input/TestMultiMRInput.java @@ -102,7 +102,7 @@ public void test0PhysicalInputs() throws Exception { jobConf.setInputFormat(org.apache.hadoop.mapred.SequenceFileInputFormat.class); FileInputFormat.setInputPaths(jobConf, workDir); - InputContext inputContext = createTezInputContext(jobConf); + InputContext inputContext = createTezInputContext(jobConf, new Configuration(false)); MultiMRInput mMrInput = new MultiMRInput(inputContext, 0); @@ -121,6 +121,25 @@ public void test0PhysicalInputs() throws Exception { } } + @Test(timeout = 5000) + public void testConfigMerge() throws Exception { + JobConf jobConf = new JobConf(false); + jobConf.set("payload-key", "payload-value"); + + Configuration localConfig = new Configuration(false); + localConfig.set("local-key", "local-value"); + + InputContext inputContext = createTezInputContext(jobConf, localConfig); + + MultiMRInputForTest input = new MultiMRInputForTest(inputContext, 1); + input.initialize(); + + Configuration mergedConfig = input.getConfiguration(); + + assertEquals("local-value", mergedConfig.get("local-key")); + assertEquals("payload-value", mergedConfig.get("payload-key")); + } + @Test(timeout = 5000) public void testSingleSplit() throws Exception { @@ -129,7 +148,7 @@ public void testSingleSplit() throws Exception { jobConf.setInputFormat(org.apache.hadoop.mapred.SequenceFileInputFormat.class); FileInputFormat.setInputPaths(jobConf, workDir); - InputContext inputContext = createTezInputContext(jobConf); + InputContext inputContext = createTezInputContext(jobConf, new Configuration(false)); MultiMRInput input = new MultiMRInput(inputContext, 1); input.initialize(); @@ -180,7 +199,7 @@ public void testNewFormatSplits() throws Exception { splitProto.toByteString().asReadOnlyByteBuffer()); // Create input context. - InputContext inputContext = createTezInputContext(conf); + InputContext inputContext = createTezInputContext(conf, new Configuration(false)); // Create the MR input object and process the event MultiMRInput input = new MultiMRInput(inputContext, 1); @@ -198,7 +217,7 @@ public void testMultipleSplits() throws Exception { jobConf.setInputFormat(org.apache.hadoop.mapred.SequenceFileInputFormat.class); FileInputFormat.setInputPaths(jobConf, workDir); - InputContext inputContext = createTezInputContext(jobConf); + InputContext inputContext = createTezInputContext(jobConf, new Configuration(false)); MultiMRInput input = new MultiMRInput(inputContext, 2); input.initialize(); @@ -265,7 +284,7 @@ public void testExtraEvents() throws Exception { jobConf.setInputFormat(org.apache.hadoop.mapred.SequenceFileInputFormat.class); FileInputFormat.setInputPaths(jobConf, workDir); - InputContext inputContext = createTezInputContext(jobConf); + InputContext inputContext = createTezInputContext(jobConf, new Configuration(false)); MultiMRInput input = new MultiMRInput(inputContext, 1); input.initialize(); @@ -308,10 +327,10 @@ private LinkedHashMap createSplits(int splitCount, Path work return data; } - private InputContext createTezInputContext(Configuration conf) throws Exception { + private InputContext createTezInputContext(Configuration payloadConf, Configuration baseConf) throws Exception { MRInputUserPayloadProto.Builder builder = MRInputUserPayloadProto.newBuilder(); builder.setGroupingEnabled(false); - builder.setConfigurationBytes(TezUtils.createByteStringFromConf(conf)); + builder.setConfigurationBytes(TezUtils.createByteStringFromConf(payloadConf)); byte[] payload = builder.build().toByteArray(); ApplicationId applicationId = ApplicationId.newInstance(10000, 1); @@ -330,6 +349,7 @@ private InputContext createTezInputContext(Configuration conf) throws Exception doReturn(UUID.randomUUID().toString()).when(inputContext).getUniqueIdentifier(); doReturn("taskVertexName").when(inputContext).getTaskVertexName(); doReturn(UserPayload.create(ByteBuffer.wrap(payload))).when(inputContext).getUserPayload(); + doReturn(baseConf).when(inputContext).getContainerConfiguration(); return inputContext; } diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutput.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutput.java index c60ca228b3..bfc09dc9b8 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutput.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutput.java @@ -94,7 +94,8 @@ public void testNewAPI_TextOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); @@ -109,6 +110,27 @@ public void testNewAPI_TextOutputFormat() throws Exception { assertEquals(FileOutputCommitter.class, output.committer.getClass()); } + @Test + public void testMergeConfig() throws Exception { + String outputPath = "/tmp/output"; + Configuration localConf = new Configuration(false); + localConf.set("local-key", "local-value"); + DataSinkDescriptor dataSink = MROutput + .createConfigBuilder(localConf, org.apache.hadoop.mapred.TextOutputFormat.class, outputPath) + .build(); + + Configuration baseConf = new Configuration(false); + baseConf.set("base-key", "base-value"); + + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), baseConf); + MROutput output = new MROutput(outputContext, 2); + output.initialize(); + + Configuration mergedConf = output.jobConf; + assertEquals("local-value", mergedConf.get("local-key")); + assertEquals("base-value", mergedConf.get("base-key")); + } + @Test(timeout = 5000) public void testOldAPI_TextOutputFormat() throws Exception { Configuration conf = new Configuration(); @@ -119,7 +141,8 @@ public void testOldAPI_TextOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); @@ -144,7 +167,8 @@ public void testNewAPI_SequenceFileOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); assertEquals(true, output.useNewApi); @@ -169,7 +193,8 @@ public void testOldAPI_SequenceFileOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); assertEquals(false, output.useNewApi); @@ -194,7 +219,8 @@ public void testNewAPI_WorkOutputPathOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); @@ -220,7 +246,8 @@ public void testOldAPI_WorkOutputPathOutputFormat() throws Exception { tmpDir.getPath()) .build(); - OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload()); + OutputContext outputContext = createMockOutputContext(dataSink.getOutputDescriptor().getUserPayload(), + new Configuration(false)); MROutput output = new MROutput(outputContext, 2); output.initialize(); @@ -235,7 +262,7 @@ public void testOldAPI_WorkOutputPathOutputFormat() throws Exception { assertEquals(org.apache.hadoop.mapred.FileOutputCommitter.class, output.committer.getClass()); } - private OutputContext createMockOutputContext(UserPayload payload) { + private OutputContext createMockOutputContext(UserPayload payload, Configuration baseConf) { OutputContext outputContext = mock(OutputContext.class); ApplicationId appId = ApplicationId.newInstance(System.currentTimeMillis(), 1); when(outputContext.getUserPayload()).thenReturn(payload); @@ -243,6 +270,7 @@ private OutputContext createMockOutputContext(UserPayload payload) { when(outputContext.getTaskVertexIndex()).thenReturn(1); when(outputContext.getTaskAttemptNumber()).thenReturn(1); when(outputContext.getCounters()).thenReturn(new TezCounters()); + when(outputContext.getContainerConfiguration()).thenReturn(baseConf); return outputContext; } diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutputLegacy.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutputLegacy.java index 01b5c84e70..60596be89d 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutputLegacy.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMROutputLegacy.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.hadoop.io.NullWritable; import org.apache.hadoop.io.Text; @@ -182,6 +183,7 @@ private OutputContext createMockOutputContext(UserPayload payload) { when(outputContext.getTaskVertexIndex()).thenReturn(1); when(outputContext.getTaskAttemptNumber()).thenReturn(1); when(outputContext.getCounters()).thenReturn(new TezCounters()); + when(outputContext.getContainerConfiguration()).thenReturn(new Configuration(false)); return outputContext; } } diff --git a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMultiMROutput.java b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMultiMROutput.java index c8eca16027..2662827678 100644 --- a/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMultiMROutput.java +++ b/tez-mapreduce/src/test/java/org/apache/tez/mapreduce/output/TestMultiMROutput.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.Text; import org.apache.hadoop.mapred.JobConf; import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter; @@ -107,6 +108,34 @@ public void testInvalidBasePath() throws Exception { } } + @Test + public void testMergeConf() throws Exception { + JobConf payloadConf = new JobConf(); + payloadConf.set("local-key", "local-value"); + DataSinkDescriptor dataSink = MultiMROutput.createConfigBuilder( + payloadConf, SequenceFileOutputFormat.class, "/output", false).build(); + + Configuration baseConf = new Configuration(false); + baseConf.set("base-key", "base-value"); + + OutputContext outputContext = mock(OutputContext.class); + ApplicationId appId = ApplicationId.newInstance(System.currentTimeMillis(), 1); + when(outputContext.getUserPayload()).thenReturn(dataSink.getOutputDescriptor().getUserPayload()); + when(outputContext.getApplicationId()).thenReturn(appId); + when(outputContext.getTaskVertexIndex()).thenReturn(1); + when(outputContext.getTaskAttemptNumber()).thenReturn(1); + when(outputContext.getCounters()).thenReturn(new TezCounters()); + when(outputContext.getStatisticsReporter()).thenReturn(mock(OutputStatisticsReporter.class)); + when(outputContext.getContainerConfiguration()).thenReturn(baseConf); + + MultiMROutput output = new MultiMROutput(outputContext, 2); + output.initialize(); + + Configuration mergedConf = output.jobConf; + assertEquals("base-value", mergedConf.get("base-key")); + assertEquals("local-value", mergedConf.get("local-key")); + } + private OutputContext createMockOutputContext(UserPayload payload) { OutputContext outputContext = mock(OutputContext.class); ApplicationId appId = ApplicationId.newInstance(System.currentTimeMillis(), 1); @@ -117,6 +146,7 @@ private OutputContext createMockOutputContext(UserPayload payload) { when(outputContext.getCounters()).thenReturn(new TezCounters()); when(outputContext.getStatisticsReporter()).thenReturn( mock(OutputStatisticsReporter.class)); + when(outputContext.getContainerConfiguration()).thenReturn(new Configuration(false)); return outputContext; } diff --git a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezTaskContextImpl.java b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezTaskContextImpl.java index dccde823e7..a47dac1e0a 100644 --- a/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezTaskContextImpl.java +++ b/tez-runtime-internals/src/main/java/org/apache/tez/runtime/api/impl/TezTaskContextImpl.java @@ -56,6 +56,7 @@ public abstract class TezTaskContextImpl implements TaskContext, Closeable { protected final String taskVertexName; protected final TezTaskAttemptID taskAttemptID; private final TezCounters counters; + private Configuration configuration; private String[] workDirs; private String uniqueIdentifier; protected final LogicalIOProcessorRuntimeTask runtimeTask; @@ -91,6 +92,7 @@ public TezTaskContextImpl(Configuration conf, String[] workDirs, int appAttemptN Objects.requireNonNull(descriptor, "descriptor is null"); Objects.requireNonNull(sharedExecutor, "sharedExecutor is null"); this.dagName = dagName; + this.configuration = conf; this.taskVertexName = taskVertexName; this.taskAttemptID = taskAttemptID; this.counters = counters; @@ -135,6 +137,11 @@ public int getTaskAttemptNumber() { return taskAttemptID.getId(); } + @Override + public Configuration getContainerConfiguration() { + return configuration; + } + @Override public String getDAGName() { return dagName; diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/OrderedGroupedKVInput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/OrderedGroupedKVInput.java index c1879bc364..2b405bb343 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/OrderedGroupedKVInput.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/OrderedGroupedKVInput.java @@ -28,6 +28,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import com.google.common.annotations.VisibleForTesting; +import org.apache.tez.common.TezUtils; import org.apache.tez.runtime.api.ProgressFailedException; import org.apache.tez.runtime.library.api.IOInterruptedException; import org.apache.tez.runtime.library.common.Constants; @@ -37,7 +38,6 @@ import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.io.RawComparator; -import org.apache.tez.common.TezUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.counters.TaskCounter; import org.apache.tez.common.counters.TezCounter; @@ -97,7 +97,7 @@ public OrderedGroupedKVInput(InputContext inputContext, int numPhysicalInputs) { @Override public synchronized List initialize() throws IOException { - this.conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); + this.conf = TezUtils.createConfFromBaseConfAndPayload(getContext()); if (this.getNumPhysicalInputs() == 0) { getContext().requestInitialMemory(0l, null); diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/UnorderedKVInput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/UnorderedKVInput.java index 401066dfc4..1db786995a 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/UnorderedKVInput.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/input/UnorderedKVInput.java @@ -24,6 +24,7 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.tez.common.TezUtils; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.runtime.api.ProgressFailedException; import org.apache.tez.runtime.library.common.Constants; @@ -36,7 +37,6 @@ import org.apache.hadoop.io.compress.CompressionCodec; import org.apache.hadoop.io.compress.DefaultCodec; import org.apache.hadoop.util.ReflectionUtils; -import org.apache.tez.common.TezUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.counters.TaskCounter; import org.apache.tez.common.counters.TezCounter; @@ -88,7 +88,7 @@ public UnorderedKVInput(InputContext inputContext, int numPhysicalInputs) { @Override public synchronized List initialize() throws Exception { Preconditions.checkArgument(getNumPhysicalInputs() != -1, "Number of Inputs has not been set"); - this.conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); + this.conf = TezUtils.createConfFromBaseConfAndPayload(getContext()); if (getNumPhysicalInputs() == 0) { getContext().requestInitialMemory(0l, null); diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/OrderedPartitionedKVOutput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/OrderedPartitionedKVOutput.java index 86c20dd9e3..676fe17a5f 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/OrderedPartitionedKVOutput.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/OrderedPartitionedKVOutput.java @@ -30,6 +30,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.Lists; +import org.apache.tez.common.TezUtils; import org.apache.tez.runtime.library.conf.OrderedPartitionedKVOutputConfig.SorterImpl; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,7 +41,6 @@ import org.apache.hadoop.fs.RawLocalFileSystem; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; -import org.apache.tez.common.TezUtils; import org.apache.tez.dag.api.TezConfiguration; import org.apache.tez.runtime.api.AbstractLogicalOutput; import org.apache.tez.runtime.api.Event; @@ -90,7 +90,7 @@ public OrderedPartitionedKVOutput(OutputContext outputContext, int numPhysicalOu @Override public synchronized List initialize() throws IOException { this.startTime = System.nanoTime(); - this.conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); + this.conf = TezUtils.createConfFromBaseConfAndPayload(getContext()); this.localFs = (RawLocalFileSystem) FileSystem.getLocal(conf).getRaw(); // Initializing this parametr in this conf since it is used in multiple diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedKVOutput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedKVOutput.java index 85368f6ea9..e7a4429d95 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedKVOutput.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedKVOutput.java @@ -25,6 +25,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.tez.common.TezUtils; import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -32,7 +33,6 @@ import org.apache.hadoop.classification.InterfaceAudience.Private; import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.conf.Configuration; -import org.apache.tez.common.TezUtils; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.counters.TaskCounter; @@ -62,8 +62,9 @@ public class UnorderedKVOutput extends AbstractLogicalOutput { @VisibleForTesting UnorderedPartitionedKVWriter kvWriter; - - private Configuration conf; + + @VisibleForTesting + Configuration conf; private MemoryUpdateCallbackHandler memoryUpdateCallbackHandler; private final AtomicBoolean isStarted = new AtomicBoolean(false); @@ -76,7 +77,7 @@ public UnorderedKVOutput(OutputContext outputContext, int numPhysicalOutputs) { @Override public synchronized List initialize() throws Exception { - this.conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); + this.conf = TezUtils.createConfFromBaseConfAndPayload(getContext()); this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, getContext().getWorkDirs()); diff --git a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedPartitionedKVOutput.java b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedPartitionedKVOutput.java index 5e223d6c40..439b732db5 100644 --- a/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedPartitionedKVOutput.java +++ b/tez-runtime-library/src/main/java/org/apache/tez/runtime/library/output/UnorderedPartitionedKVOutput.java @@ -26,14 +26,15 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.apache.tez.common.Preconditions; +import com.google.common.annotations.VisibleForTesting; +import org.apache.tez.common.TezUtils; import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.hadoop.classification.InterfaceAudience; import org.apache.hadoop.classification.InterfaceAudience.Public; import org.apache.hadoop.conf.Configuration; -import org.apache.tez.common.TezUtils; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezRuntimeFrameworkConfigs; import org.apache.tez.common.counters.TaskCounter; @@ -57,7 +58,8 @@ public class UnorderedPartitionedKVOutput extends AbstractLogicalOutput { private static final Logger LOG = LoggerFactory.getLogger(UnorderedPartitionedKVOutput.class); - private Configuration conf; + @VisibleForTesting + Configuration conf; private MemoryUpdateCallbackHandler memoryUpdateCallbackHandler; private UnorderedPartitionedKVWriter kvWriter; private final AtomicBoolean isStarted = new AtomicBoolean(false); @@ -68,7 +70,7 @@ public UnorderedPartitionedKVOutput(OutputContext outputContext, int numPhysical @Override public synchronized List initialize() throws Exception { - this.conf = TezUtils.createConfFromUserPayload(getContext().getUserPayload()); + this.conf = TezUtils.createConfFromBaseConfAndPayload(getContext()); this.conf.setStrings(TezRuntimeFrameworkConfigs.LOCAL_DIRS, getContext().getWorkDirs()); this.conf.setInt(TezRuntimeFrameworkConfigs.TEZ_RUNTIME_NUM_EXPECTED_PARTITIONS, getNumPhysicalOutputs()); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/input/TestOrderedGroupedKVInput.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/input/TestOrderedGroupedKVInput.java index d4be80211a..56b6805a63 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/input/TestOrderedGroupedKVInput.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/input/TestOrderedGroupedKVInput.java @@ -14,6 +14,7 @@ package org.apache.tez.runtime.library.input; +import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doReturn; @@ -57,6 +58,34 @@ public void testInterruptWhileAwaitingInput() throws IOException, TezException { } + @Test + public void testMergeConfig() throws IOException, TezException { + Configuration baseConf = new Configuration(false); + baseConf.set("base-key", "base-value"); + + Configuration payloadConf = new Configuration(false); + payloadConf.set("local-key", "local-value"); + + InputContext inputContext = mock(InputContext.class); + + UserPayload payLoad = TezUtils.createUserPayloadFromConf(payloadConf); + String[] workingDirs = new String[]{"workDir1"}; + TezCounters counters = new TezCounters(); + + + doReturn(payLoad).when(inputContext).getUserPayload(); + doReturn(workingDirs).when(inputContext).getWorkDirs(); + doReturn(counters).when(inputContext).getCounters(); + doReturn(baseConf).when(inputContext).getContainerConfiguration(); + + OrderedGroupedKVInput input = new OrderedGroupedKVInput(inputContext, 1); + input.initialize(); + + Configuration mergedConf = input.conf; + assertEquals("base-value", mergedConf.get("base-key")); + assertEquals("local-value", mergedConf.get("local-key")); + } + private InputContext createMockInputContext() throws IOException { InputContext inputContext = mock(InputContext.class); @@ -70,6 +99,7 @@ private InputContext createMockInputContext() throws IOException { doReturn(workingDirs).when(inputContext).getWorkDirs(); doReturn(200 * 1024 * 1024l).when(inputContext).getTotalMemoryAvailableToTask(); doReturn(counters).when(inputContext).getCounters(); + doReturn(new Configuration(false)).when(inputContext).getContainerConfiguration(); doAnswer(new Answer() { @Override diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java index 573d53e7b1..b81c2bd036 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/OutputTestHelpers.java @@ -51,10 +51,12 @@ static OutputContext createOutputContext() throws IOException { doReturn(200 * 1024 * 1024l).when(outputContext).getTotalMemoryAvailableToTask(); doReturn(counters).when(outputContext).getCounters(); doReturn(statsReporter).when(outputContext).getStatisticsReporter(); + doReturn(new Configuration(false)).when(outputContext).getContainerConfiguration(); return outputContext; } - static OutputContext createOutputContext(Configuration conf, Path workingDir) throws IOException { + static OutputContext createOutputContext(Configuration conf, Configuration userPayloadConf, Path workingDir) + throws IOException { OutputContext ctx = mock(OutputContext.class); doAnswer(new Answer() { @Override public Void answer(InvocationOnMock invocation) throws Throwable { @@ -65,7 +67,8 @@ static OutputContext createOutputContext(Configuration conf, Path workingDir) th return null; } }).when(ctx).requestInitialMemory(anyLong(), any(MemoryUpdateCallback.class)); - doReturn(TezUtils.createUserPayloadFromConf(conf)).when(ctx).getUserPayload(); + doReturn(conf).when(ctx).getContainerConfiguration(); + doReturn(TezUtils.createUserPayloadFromConf(userPayloadConf)).when(ctx).getUserPayload(); doReturn("destinationVertex").when(ctx).getDestinationVertexName(); doReturn("UUID").when(ctx).getUniqueIdentifier(); doReturn(new String[] { workingDir.toString() }).when(ctx).getWorkDirs(); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileSortedOutput.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileSortedOutput.java index 77620258dc..2c9c3b2ace 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileSortedOutput.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileSortedOutput.java @@ -44,7 +44,6 @@ import org.apache.tez.runtime.library.common.sort.impl.dflt.DefaultSorter; import org.apache.tez.runtime.library.conf.OrderedPartitionedKVOutputConfig.SorterImpl; import org.apache.tez.runtime.library.partitioner.HashPartitioner; -import org.apache.tez.runtime.library.common.shuffle.ShuffleUtils; import org.apache.tez.runtime.library.shuffle.impl.ShuffleUserPayloads; import org.junit.After; import org.junit.Assert; @@ -378,6 +377,7 @@ public void testAllEmptyPartition() throws Exception { private OutputContext createTezOutputContext() throws IOException { String[] workingDirs = { workingDir.toString() }; + Configuration localConf = new Configuration(false); UserPayload payLoad = TezUtils.createUserPayloadFromConf(conf); DataOutputBuffer serviceProviderMetaData = new DataOutputBuffer(); serviceProviderMetaData.writeInt(PORT); @@ -400,6 +400,7 @@ private OutputContext createTezOutputContext() throws IOException { OutputContext context = mock(OutputContext.class); + doReturn(localConf).when(context).getContainerConfiguration(); doReturn(counters).when(context).getCounters(); doReturn(workingDirs).when(context).getWorkDirs(); doReturn(payLoad).when(context).getUserPayload(); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileUnorderedKVOutput.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileUnorderedKVOutput.java index 393ac2e71d..963300cd40 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileUnorderedKVOutput.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOnFileUnorderedKVOutput.java @@ -128,7 +128,7 @@ public void testGeneratedDataMovementEvent() throws Exception { conf.set(TezRuntimeConfiguration.TEZ_RUNTIME_VALUE_CLASS, IntWritable.class.getName()); TezSharedExecutor sharedExecutor = new TezSharedExecutor(conf); - OutputContext outputContext = createOutputContext(conf, sharedExecutor); + OutputContext outputContext = createOutputContext(conf, new Configuration(false), sharedExecutor); UnorderedKVOutput kvOutput = new UnorderedKVOutput(outputContext, 1); @@ -161,6 +161,26 @@ public void testGeneratedDataMovementEvent() throws Exception { sharedExecutor.shutdownNow(); } + @Test + public void testMergeConfig() throws Exception { + Configuration baseConf = new Configuration(false); + baseConf.set("local-key", "local-value"); + + Configuration payloadConf = new Configuration(false); + payloadConf.set("base-key", "base-value"); + + TezSharedExecutor sharedExecutor = new TezSharedExecutor(baseConf); + OutputContext outputContext = createOutputContext(payloadConf, baseConf, sharedExecutor); + + UnorderedKVOutput kvOutput = new UnorderedKVOutput(outputContext, 1); + + kvOutput.initialize(); + + Configuration mergedConf = kvOutput.conf; + assertEquals("local-value", mergedConf.get("local-key")); + assertEquals("base-value", mergedConf.get("base-key")); + } + @Test(timeout = 30000) @SuppressWarnings("unchecked") public void testWithPipelinedShuffle() throws Exception { @@ -173,7 +193,7 @@ public void testWithPipelinedShuffle() throws Exception { conf.setInt(TezRuntimeConfiguration.TEZ_RUNTIME_UNORDERED_OUTPUT_BUFFER_SIZE_MB, 1); TezSharedExecutor sharedExecutor = new TezSharedExecutor(conf); - OutputContext outputContext = createOutputContext(conf, sharedExecutor); + OutputContext outputContext = createOutputContext(conf, new Configuration(false), sharedExecutor); UnorderedKVOutput kvOutput = new UnorderedKVOutput(outputContext, 1); @@ -211,8 +231,8 @@ public void testWithPipelinedShuffle() throws Exception { sharedExecutor.shutdownNow(); } - private OutputContext createOutputContext(Configuration conf, TezSharedExecutor sharedExecutor) - throws IOException { + private OutputContext createOutputContext(Configuration payloadConf, Configuration baseConf, + TezSharedExecutor sharedExecutor) throws IOException { int appAttemptNumber = 1; TezUmbilical tezUmbilical = mock(TezUmbilical.class); String dagName = "currentDAG"; @@ -222,7 +242,7 @@ private OutputContext createOutputContext(Configuration conf, TezSharedExecutor TezVertexID vertexID = TezVertexID.getInstance(dagID, 1); TezTaskID taskID = TezTaskID.getInstance(vertexID, 1); TezTaskAttemptID taskAttemptID = TezTaskAttemptID.getInstance(taskID, 1); - UserPayload userPayload = TezUtils.createUserPayloadFromConf(conf); + UserPayload userPayload = TezUtils.createUserPayloadFromConf(payloadConf); TaskSpec mockSpec = mock(TaskSpec.class); when(mockSpec.getInputs()).thenReturn(Collections.singletonList(mock(InputSpec.class))); @@ -237,17 +257,17 @@ private OutputContext createOutputContext(Configuration conf, TezSharedExecutor ByteBuffer bb = ByteBuffer.allocate(4); bb.putInt(shufflePort); bb.position(0); - AuxiliaryServiceHelper.setServiceDataIntoEnv(conf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID, + AuxiliaryServiceHelper.setServiceDataIntoEnv(payloadConf.get(TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID, TezConfiguration.TEZ_AM_SHUFFLE_AUXILIARY_SERVICE_ID_DEFAULT), bb, auxEnv); OutputDescriptor outputDescriptor = mock(OutputDescriptor.class); when(outputDescriptor.getClassName()).thenReturn("OutputDescriptor"); - OutputContext realOutputContext = new TezOutputContextImpl(conf, new String[] {workDir.toString()}, + OutputContext realOutputContext = new TezOutputContextImpl(baseConf, new String[] {workDir.toString()}, appAttemptNumber, tezUmbilical, dagName, taskVertexName, destinationVertexName, -1, taskAttemptID, 0, userPayload, runtimeTask, - null, auxEnv, new MemoryDistributor(1, 1, conf) , outputDescriptor, null, + null, auxEnv, new MemoryDistributor(1, 1, payloadConf), outputDescriptor, null, new ExecutionContextImpl("localhost"), 2048, new TezSharedExecutor(defaultConf)); verify(runtimeTask, times(1)).addAndGetTezCounter(destinationVertexName); verify(runtimeTask, times(1)).getTaskStatistics(); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOrderedPartitionedKVOutput2.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOrderedPartitionedKVOutput2.java index f226b7c385..29ce890309 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOrderedPartitionedKVOutput2.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestOrderedPartitionedKVOutput2.java @@ -69,7 +69,7 @@ public void cleanup() throws IOException { @Test(timeout = 5000) public void testNonStartedOutput() throws IOException { - OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir); + OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, conf, workingDir); int numPartitions = 10; OrderedPartitionedKVOutput output = new OrderedPartitionedKVOutput(outputContext, numPartitions); output.initialize(); @@ -94,9 +94,24 @@ public void testNonStartedOutput() throws IOException { } } + @Test(timeout = 5000) + public void testConfigMerge() throws IOException { + Configuration localConf = new Configuration(conf); + localConf.set("config-from-local", "config-from-local-value"); + Configuration payload = new Configuration(false); + payload.set("config-from-payload", "config-from-payload-value"); + OutputContext outputContext = OutputTestHelpers.createOutputContext(localConf, payload, workingDir); + int numPartitions = 10; + OrderedPartitionedKVOutput output = new OrderedPartitionedKVOutput(outputContext, numPartitions); + output.initialize(); + Configuration configAfterMerge = output.conf; + assertEquals("config-from-local-value", configAfterMerge.get("config-from-local")); + assertEquals("config-from-payload-value", configAfterMerge.get("config-from-payload")); + } + @Test(timeout = 10000) public void testClose() throws Exception { - OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir); + OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, conf, workingDir); int numPartitions = 10; OrderedPartitionedKVOutput output = new OrderedPartitionedKVOutput(outputContext, numPartitions); output.initialize(); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedKVOutput2.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedKVOutput2.java index 792b03f572..a52788e716 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedKVOutput2.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedKVOutput2.java @@ -93,9 +93,24 @@ public void testNonStartedOutput() throws Exception { } } + @Test(timeout = 5000) + public void testConfigMerge() throws Exception { + Configuration localConf = new Configuration(conf); + localConf.set("config-from-local", "config-from-local-value"); + Configuration payload = new Configuration(false); + payload.set("config-from-payload", "config-from-payload-value"); + OutputContext outputContext = OutputTestHelpers.createOutputContext(localConf, payload, workingDir); + int numPartitions = 10; + UnorderedKVOutput output = new UnorderedKVOutput(outputContext, numPartitions); + output.initialize(); + Configuration configAfterMerge = output.conf; + assertEquals("config-from-local-value", configAfterMerge.get("config-from-local")); + assertEquals("config-from-payload-value", configAfterMerge.get("config-from-payload")); + } + @Test(timeout = 10000) public void testClose() throws Exception { - OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, workingDir); + OutputContext outputContext = OutputTestHelpers.createOutputContext(conf, conf, workingDir); int numPartitions = 1; UnorderedKVOutput output = new UnorderedKVOutput(outputContext, numPartitions); output.initialize(); diff --git a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedPartitionedKVOutput2.java b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedPartitionedKVOutput2.java index eec4bf59e3..52e06300dd 100644 --- a/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedPartitionedKVOutput2.java +++ b/tez-runtime-library/src/test/java/org/apache/tez/runtime/library/output/TestUnorderedPartitionedKVOutput2.java @@ -22,6 +22,8 @@ import java.util.List; import com.google.protobuf.ByteString; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; import org.apache.tez.common.TezCommonUtils; import org.apache.tez.common.TezUtilsInternal; import org.apache.tez.runtime.api.Event; @@ -59,4 +61,21 @@ public void testNonStartedOutput() throws Exception { assertTrue(emptyPartionsBitSet.get(i)); } } + + @Test + public void testConfigMerge() throws Exception { + Configuration userPayloadConf = new Configuration(false); + Configuration baseConf = new Configuration(false); + + userPayloadConf.set("local-key", "local-value"); + baseConf.set("base-key", "base-value"); + OutputContext outputContext = OutputTestHelpers.createOutputContext( + userPayloadConf, baseConf, new Path("/")); + UnorderedPartitionedKVOutput output = + new UnorderedPartitionedKVOutput(outputContext, 1); + output.initialize(); + Configuration mergedConf = output.conf; + assertEquals("base-value", mergedConf.get("base-key")); + assertEquals("local-value", mergedConf.get("local-key")); + } }