diff --git a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/CallerContext.java b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/CallerContext.java index 9f9c9741f36bb..b5f276004cd87 100644 --- a/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/CallerContext.java +++ b/hadoop-common-project/hadoop-common/src/main/java/org/apache/hadoop/ipc/CallerContext.java @@ -43,6 +43,14 @@ @InterfaceStability.Evolving public final class CallerContext { public static final Charset SIGNATURE_ENCODING = StandardCharsets.UTF_8; + + /** + * The illegal characters include '\t', '\n', '='. + * User should not set illegal characters to the context. + */ + private static final Set ILLEGAL_CHARACTERS = + Collections.unmodifiableSet(new HashSet<>(Arrays.asList("\t", "\n", "="))); + /** The caller context. * * It will be truncated if it exceeds the maximum allowed length in @@ -74,9 +82,24 @@ public byte[] getSignature() { null : Arrays.copyOf(signature, signature.length); } + /** + * Whether the context is valid. + * The context should not contain '\t', '\n', '='. + * Because the context could be written to audit log. + */ @InterfaceAudience.Private public boolean isContextValid() { - return context != null && !context.isEmpty(); + if (context == null || context.isEmpty()) { + return false; + } + + for (String str: ILLEGAL_CHARACTERS) { + if (context.contains(str)) { + return false; + } + } + + return true; } @Override @@ -117,13 +140,6 @@ public String toString() { /** The caller context builder. */ public static final class Builder { private static final String KEY_VALUE_SEPARATOR = ":"; - /** - * The illegal separators include '\t', '\n', '='. - * User should not set illegal separator. - */ - private static final Set ILLEGAL_SEPARATORS = - Collections.unmodifiableSet( - new HashSet<>(Arrays.asList("\t", "\n", "="))); private final String fieldSeparator; private final StringBuilder sb = new StringBuilder(); private byte[] signature; @@ -156,7 +172,7 @@ public Builder(String context, String separator) { * @param separator the separator of fields. */ private void checkFieldSeparator(String separator) { - if (ILLEGAL_SEPARATORS.contains(separator)) { + if (ILLEGAL_CHARACTERS.contains(separator)) { throw new IllegalArgumentException("Illegal field separator: " + separator); } @@ -164,8 +180,6 @@ private void checkFieldSeparator(String separator) { /** * Whether the field is valid. - * The field should not contain '\t', '\n', '='. - * Because the context could be written to audit log. * @param field one of the fields in context. * @return true if the field is not null or empty. */ diff --git a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestCallerContext.java b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestCallerContext.java index bb4a119e7db29..259ca839444af 100644 --- a/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestCallerContext.java +++ b/hadoop-common-project/hadoop-common/src/test/java/org/apache/hadoop/ipc/TestCallerContext.java @@ -31,8 +31,7 @@ public void testBuilderAppend() { CallerContext.Builder builder = new CallerContext.Builder(null, conf); CallerContext context = builder.append("context1") .append("context2").append("key3", "value3").build(); - Assert.assertEquals(true, - context.getContext().contains("$")); + Assert.assertTrue(context.getContext().contains("$")); String[] items = context.getContext().split("\\$"); Assert.assertEquals(3, items.length); Assert.assertEquals("key3:value3", items[2]); @@ -74,11 +73,24 @@ public void testBuilderAppendIfAbsent() { } @Test(expected = IllegalArgumentException.class) - public void testNewBuilder() { + public void testIllegalSeparator() { Configuration conf = new Configuration(); // Set illegal separator. conf.set(HADOOP_CALLER_CONTEXT_SEPARATOR_KEY, "\t"); CallerContext.Builder builder = new CallerContext.Builder(null, conf); builder.build(); } + + @Test + public void testValidateCallerContext() { + // CallerContext should not contain '\t', '\n', '='. + CallerContext context = new CallerContext.Builder("context1\ncontext2").build(); + Assert.assertFalse(context.isContextValid()); + + context = new CallerContext.Builder("context1\tcontext2").build(); + Assert.assertFalse(context.isContextValid()); + + context = new CallerContext.Builder("context1=context2").build(); + Assert.assertFalse(context.isContextValid()); + } }