diff --git a/src/main/java/org/apache/xml/security/signature/XMLSignature.java b/src/main/java/org/apache/xml/security/signature/XMLSignature.java index b2ec541e5..658bcaf37 100644 --- a/src/main/java/org/apache/xml/security/signature/XMLSignature.java +++ b/src/main/java/org/apache/xml/security/signature/XMLSignature.java @@ -684,11 +684,7 @@ private void setSignatureValueElement(byte[] bytes) { signatureValueElement.removeChild(signatureValueElement.getFirstChild()); } - String base64codedValue = XMLUtils.encodeToString(bytes); - - if (base64codedValue.length() > 76 && !XMLUtils.ignoreLineBreaks()) { - base64codedValue = "\n" + base64codedValue + "\n"; - } + String base64codedValue = XMLUtils.encodeElementValue(bytes); Text t = createText(base64codedValue); signatureValueElement.appendChild(t); diff --git a/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java b/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java index efa2fa5a8..7e5e5e293 100644 --- a/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java +++ b/src/main/java/org/apache/xml/security/stax/impl/processor/output/AbstractEncryptOutputProcessor.java @@ -24,12 +24,7 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.security.spec.AlgorithmParameterSpec; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Deque; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; import javax.crypto.Cipher; @@ -40,7 +35,6 @@ import javax.xml.stream.XMLStreamConstants; import javax.xml.stream.XMLStreamException; -import org.apache.commons.codec.binary.Base64OutputStream; import org.apache.xml.security.algorithms.JCEMapper; import org.apache.xml.security.encryption.XMLCipherUtil; import org.apache.xml.security.exceptions.XMLSecurityException; @@ -175,12 +169,7 @@ public void init(OutputProcessorChain outputProcessorChain) throws XMLSecurityEx symmetricCipher.init(Cipher.ENCRYPT_MODE, encryptionPartDef.getSymmetricKey(), parameterSpec); characterEventGeneratorOutputStream = new CharacterEventGeneratorOutputStream(); - Base64OutputStream base64EncoderStream = null; //NOPMD - if (XMLUtils.isIgnoreLineBreaks()) { - base64EncoderStream = new Base64OutputStream(characterEventGeneratorOutputStream, true, 0, null); - } else { - base64EncoderStream = new Base64OutputStream(characterEventGeneratorOutputStream, true); - } + OutputStream base64EncoderStream = XMLUtils.encodeStream(characterEventGeneratorOutputStream); //NOPMD base64EncoderStream.write(iv); OutputStream outputStream = new CipherOutputStream(base64EncoderStream, symmetricCipher); //NOPMD diff --git a/src/main/java/org/apache/xml/security/utils/ElementProxy.java b/src/main/java/org/apache/xml/security/utils/ElementProxy.java index 7e7828f2f..298fbbe01 100644 --- a/src/main/java/org/apache/xml/security/utils/ElementProxy.java +++ b/src/main/java/org/apache/xml/security/utils/ElementProxy.java @@ -313,9 +313,7 @@ public void addTextElement(String text, String localname) { */ public void addBase64Text(byte[] bytes) { if (bytes != null) { - Text t = XMLUtils.ignoreLineBreaks() - ? createText(XMLUtils.encodeToString(bytes)) - : createText("\n" + XMLUtils.encodeToString(bytes) + "\n"); + Text t = createText(XMLUtils.encodeElementValue(bytes)); appendSelf(t); } } diff --git a/src/main/java/org/apache/xml/security/utils/XMLUtils.java b/src/main/java/org/apache/xml/security/utils/XMLUtils.java index 9027469cd..18170ea19 100644 --- a/src/main/java/org/apache/xml/security/utils/XMLUtils.java +++ b/src/main/java/org/apache/xml/security/utils/XMLUtils.java @@ -56,14 +56,63 @@ /** * DOM and XML accessibility and comfort functions. * + * @implNote + * Following system properties affect XML formatting: + * */ public final class XMLUtils { + private static final Logger LOG = System.getLogger(XMLUtils.class.getName()); + + private static final String IGNORE_LINE_BREAKS_PROP = "org.apache.xml.security.ignoreLineBreaks"; + private static final String BASE64_IGNORE_LINE_BREAKS_PROP = "org.apache.xml.security.base64.ignoreLineBreaks"; + private static final String BASE64_LINE_SEPARATOR_PROP = "org.apache.xml.security.base64.lineSeparator"; + private static final String BASE64_LINE_LENGTH_PROP = "org.apache.xml.security.base64.lineLength"; + private static boolean ignoreLineBreaks = AccessController.doPrivileged( - (PrivilegedAction) () -> Boolean.getBoolean("org.apache.xml.security.ignoreLineBreaks")); + (PrivilegedAction) () -> Boolean.getBoolean(IGNORE_LINE_BREAKS_PROP)); + + private static Base64FormattingOptions base64Formatting = + AccessController.doPrivileged((PrivilegedAction) () -> { + Base64FormattingOptions options = new Base64FormattingOptions(); + options.setIgnoreLineBreaks(Boolean.getBoolean(BASE64_IGNORE_LINE_BREAKS_PROP)); + + String lineSeparator = System.getProperty(BASE64_LINE_SEPARATOR_PROP); + if (lineSeparator != null) { + try { + options.setLineSeparator(Base64LineSeparator.valueOf(lineSeparator.toUpperCase())); + } catch (IllegalArgumentException e) { + LOG.log(Level.WARNING, "Illegal value of {0} property ignored: {1}", + BASE64_LINE_SEPARATOR_PROP, lineSeparator); + } + } - private static final Logger LOG = System.getLogger(XMLUtils.class.getName()); + Integer lineLength = Integer.getInteger(BASE64_LINE_LENGTH_PROP); + if (lineLength != null && lineLength >= 4) { + options.setLineLength(lineLength); + } else if (lineLength != null) { + LOG.log(Level.WARNING, "Illegal value of {0} property ignored: {1}", + BASE64_LINE_LENGTH_PROP, lineLength); + } + + return options; + }); + + private static Base64.Encoder base64Encoder = (ignoreLineBreaks || base64Formatting.isIgnoreLineBreaks()) ? + Base64.getEncoder() : + Base64.getMimeEncoder(base64Formatting.getLineLength(), base64Formatting.getLineSeparator().getBytes()); + + private static Base64.Decoder base64Decoder = Base64.getMimeDecoder(); private static XMLParser xmlParserImpl = AccessController.doPrivileged( @@ -515,18 +564,48 @@ public static void addReturnBeforeChild(Element e, Node child) { } public static String encodeToString(byte[] bytes) { - if (ignoreLineBreaks) { - return Base64.getEncoder().encodeToString(bytes); + return base64Encoder.encodeToString(bytes); + } + + /** + * Encodes bytes using Base64, with or without line breaks, depending on configuration (see {@link XMLUtils}). + * @param bytes Bytes to encode + * @return Base64 string + */ + public static String encodeElementValue(byte[] bytes) { + String encoded = encodeToString(bytes); + if (!ignoreLineBreaks && !base64Formatting.isIgnoreLineBreaks() + && encoded.length() > base64Formatting.getLineLength()) { + encoded = "\n" + encoded + "\n"; } - return Base64.getMimeEncoder().encodeToString(bytes); + return encoded; + } + + /** + * Wraps output stream for Base64 encoding. + * Output data may contain line breaks or not, depending on configuration (see {@link XMLUtils}) + * @param stream The underlying output stream to write Base64-encoded data + * @return Stream which writes binary data using Base64 encoder + */ + public static OutputStream encodeStream(OutputStream stream) { + return base64Encoder.wrap(stream); } public static byte[] decode(String encodedString) { - return Base64.getMimeDecoder().decode(encodedString); + return base64Decoder.decode(encodedString); } public static byte[] decode(byte[] encodedBytes) { - return Base64.getMimeDecoder().decode(encodedBytes); + return base64Decoder.decode(encodedBytes); + } + + /** + * Wraps input stream for Base64 decoding. + * @param stream Input stream with Base64-encoded data + * @return Input stream with decoded binary data + */ + public static InputStream decodeStream(InputStream stream) { + return base64Decoder.wrap(stream); } public static boolean isIgnoreLineBreaks() { @@ -1068,4 +1147,52 @@ public static byte[] getBytes(BigInteger big, int bitlen) { return resizedBytes; } + + /** + * Aggregates formatting options for base64Binary values. + */ + static class Base64FormattingOptions { + private boolean ignoreLineBreaks = false; + private Base64LineSeparator lineSeparator = Base64LineSeparator.CRLF; + private int lineLength = 76; + + public boolean isIgnoreLineBreaks() { + return ignoreLineBreaks; + } + + public void setIgnoreLineBreaks(boolean ignoreLineBreaks) { + this.ignoreLineBreaks = ignoreLineBreaks; + } + + public Base64LineSeparator getLineSeparator() { + return lineSeparator; + } + + public void setLineSeparator(Base64LineSeparator lineSeparator) { + this.lineSeparator = lineSeparator; + } + + public int getLineLength() { + return lineLength; + } + + public void setLineLength(int lineLength) { + this.lineLength = lineLength; + } + } + + enum Base64LineSeparator { + CRLF(new byte[]{'\r', '\n'}), + LF(new byte[]{'\n'}); + + private byte[] bytes; + + Base64LineSeparator(byte[] bytes) { + this.bytes = bytes; + } + + public byte[] getBytes() { + return bytes; + } + } } diff --git a/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java b/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java new file mode 100644 index 000000000..b401291e9 --- /dev/null +++ b/src/test/java/org/apache/xml/security/utils/XMLUtilsTest.java @@ -0,0 +1,270 @@ +package org.apache.xml.security.utils; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.*; +import java.util.stream.Collectors; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.*; + +public class XMLUtilsTest { + private static final byte[] data = new byte[60]; // long enough for a line break in MIME encoding + + private Properties backup; + private ClassLoader classLoader; + + @BeforeEach + public void createClassLoader() { + /* create custom classloader to reload class in each test */ + ClassLoader parent = getClass().getClassLoader(); + Collection> classesToReload = List.of( + XMLUtils.class, + XMLUtils.Base64FormattingOptions.class, + XMLUtils.Base64LineSeparator.class + ); + classLoader = new ReloadingClassLoader(parent, classesToReload); + ModuleLayer.boot().findModule("org.apache.santuario.xmlsec").orElseThrow() + .addOpens("org.apache.xml.security.parser", classLoader.getUnnamedModule()); + } + + @BeforeEach + public void backupProperties() { + backup = new Properties(); + backup.putAll(System.getProperties()); + } + + @AfterEach + public void restoreProperties() { + System.setProperties(backup); + } + + @Test + public void testAllPropertiesUnset() throws ReflectiveOperationException, IOException { + System.clearProperty("org.apache.xml.security.ignoreLineBreaks"); + System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks"); + System.clearProperty("org.apache.xml.security.base64.lineSeparator"); + System.clearProperty("org.apache.xml.security.base64.lineLength"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, containsString("\r\n")); + OptionalInt maxLineLength = Arrays.stream(encoded.split("\r\n")).mapToInt(String::length).max(); + assertTrue(maxLineLength.isPresent()); + assertEquals(76, maxLineLength.getAsInt()); + + assertThat(elementValue, containsString(encoded)); + assertThat(elementValue, startsWith("\n")); + assertThat(elementValue, endsWith("\n")); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testIgnoreLineBreaksSet() throws ReflectiveOperationException, IOException { + System.setProperty("org.apache.xml.security.ignoreLineBreaks", "true"); + System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks"); + System.clearProperty("org.apache.xml.security.base64.lineSeparator"); + System.clearProperty("org.apache.xml.security.base64.lineLength"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, not(containsString("\r\n"))); + assertThat(encoded, not(containsString("\n"))); + assertThat(elementValue, not(containsString("\r\n"))); + assertThat(elementValue, not(containsString("\n"))); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testIgnoreLineBreaksTakesPrecedence() throws ReflectiveOperationException, IOException { + System.setProperty("org.apache.xml.security.ignoreLineBreaks", "true"); + System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "false"); + System.setProperty("org.apache.xml.security.base64.lineSeparator", "crlf"); + System.setProperty("org.apache.xml.security.base64.lineLength", "40"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, not(containsString("\r\n"))); + assertThat(encoded, not(containsString("\n"))); + assertThat(elementValue, not(containsString("\r\n"))); + assertThat(elementValue, not(containsString("\n"))); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testBase64IgnoreLineBreaksSet() throws ReflectiveOperationException, IOException { + System.clearProperty("org.apache.xml.security.ignoreLineBreaks"); + System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "true"); + System.clearProperty("org.apache.xml.security.base64.lineSeparator"); + System.clearProperty("org.apache.xml.security.base64.lineLength"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, not(containsString("\r\n"))); + assertThat(encoded, not(containsString("\n"))); + assertThat(elementValue, not(containsString("\r\n"))); + assertThat(elementValue, not(containsString("\n"))); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testBase64IgnoreLineBreaksTakesPrecedence() throws ReflectiveOperationException, IOException { + System.clearProperty("org.apache.xml.security.ignoreLineBreaks"); + System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "true"); + System.setProperty("org.apache.xml.security.base64.lineSeparator", "crlf"); + System.setProperty("org.apache.xml.security.base64.lineLength", "40"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, not(containsString("\r\n"))); + assertThat(encoded, not(containsString("\n"))); + assertThat(elementValue, not(containsString("\r\n"))); + assertThat(elementValue, not(containsString("\n"))); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testBase64CustomFormatting() throws ReflectiveOperationException, IOException { + System.clearProperty("org.apache.xml.security.ignoreLineBreaks"); + System.clearProperty("org.apache.xml.security.base64.ignoreLineBreaks"); + System.setProperty("org.apache.xml.security.base64.lineSeparator", "lf"); + System.setProperty("org.apache.xml.security.base64.lineLength", "40"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, not(containsString("\r\n"))); + assertThat(encoded, containsString("\n")); + OptionalInt maxLineLength = Arrays.stream(encoded.split("\n")).mapToInt(String::length).max(); + assertTrue(maxLineLength.isPresent()); + assertEquals(40, maxLineLength.getAsInt()); + + assertThat(elementValue, containsString(encoded)); + assertThat(elementValue, startsWith("\n")); + assertThat(elementValue, endsWith("\n")); + + assertEquals(encoded, encodedWithStream); + } + + @Test + public void testIllegalPropertiesAreIgnored() throws ReflectiveOperationException, IOException { + System.setProperty("org.apache.xml.security.ignoreLineBreaks", "illegal"); + System.setProperty("org.apache.xml.security.base64.ignoreLineBreaks", "illegal"); + System.setProperty("org.apache.xml.security.base64.lineSeparator", "illegal"); + System.setProperty("org.apache.xml.security.base64.lineLength", "illegal"); + + Class xmlUtilsClass = classLoader.loadClass(XMLUtils.class.getName()); + String encoded = encodeToString(xmlUtilsClass, data); + String elementValue = encodeElementValue(xmlUtilsClass, data); + String encodedWithStream = encodeUsingStream(xmlUtilsClass, data); + + assertThat(encoded, containsString("\r\n")); + OptionalInt maxLineLength = Arrays.stream(encoded.split("\r\n")).mapToInt(String::length).max(); + assertTrue(maxLineLength.isPresent()); + assertEquals(76, maxLineLength.getAsInt()); + + assertThat(elementValue, containsString(encoded)); + assertThat(elementValue, startsWith("\n")); + assertThat(elementValue, endsWith("\n")); + + assertEquals(encoded, encodedWithStream); + } + + private String encodeToString(Class xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException { + return (String) xmlUtilsClass.getMethod("encodeToString", byte[].class).invoke(null, (Object) bytes); + } + + private String encodeElementValue(Class xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException { + return (String) xmlUtilsClass.getMethod("encodeElementValue", byte[].class).invoke(null, (Object) bytes); + } + + private OutputStream encodeStream(Class xmlUtilsClass, OutputStream stream) throws ReflectiveOperationException { + return (OutputStream) xmlUtilsClass.getMethod("encodeStream", OutputStream.class).invoke(null, stream); + } + + private String encodeUsingStream(Class xmlUtilsClass, byte[] bytes) throws ReflectiveOperationException, IOException { + try (ByteArrayOutputStream encoded = new ByteArrayOutputStream(); + OutputStream raw = encodeStream(xmlUtilsClass, encoded)) { + raw.write(bytes); + raw.flush(); + return encoded.toString(StandardCharsets.US_ASCII); + } + } + + private static class ReloadingClassLoader extends ClassLoader { + private Collection classNames; + + public ReloadingClassLoader(ClassLoader parent, Collection> classes) { + super("TestClassLoader", parent); + this.classNames = classes.stream().map(Class::getName).collect(Collectors.toSet()); + } + + @Override + protected Class loadClass(String name, boolean resolve) throws ClassNotFoundException { + if (classNames.contains(name)) { + Class clazz = findClass(name); + if (resolve) { + resolveClass(clazz); + } + return clazz; + } + return super.loadClass(name, resolve); + } + + @Override + protected Class findClass(String name) throws ClassNotFoundException { + if (classNames.contains(name)) { + Class parentLoadedClass = getParent().loadClass(name); + String resourceName = synthesizeClassName(parentLoadedClass) + ".class"; + byte[] classData; + try (InputStream in = parentLoadedClass.getResourceAsStream(resourceName)) { + if (in == null) { + throw new ClassNotFoundException("Could not load class " + name); + } + classData = in.readAllBytes(); + } catch (IOException e) { + throw new ClassNotFoundException("Could not load class " + name, e); + } + + return defineClass(name, classData, 0, classData.length); + } + throw new ClassNotFoundException("Class not found: " + name); + } + + private String synthesizeClassName(Class clazz) { + String name = clazz.getSimpleName(); + if (clazz.isMemberClass()) name = synthesizeClassName(clazz.getEnclosingClass()) + "$" + name; + return name; + } + } +}