Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[OB3] Provide config for making transport cert optional for the token endpoint #166

Merged
merged 1 commit into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ priority=1
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,9 @@ url_pattern = "/oauth2/authorize*"
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ priority=1
name = "TokenFilter"
class = "com.wso2.openbanking.accelerator.identity.token.TokenFilter"

[tomcat.filter.init_params]
isTransportCertificateMandatory = true

[[tomcat.filter_mapping]]
name = "TokenFilter"
url_pattern = "/token"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class TokenFilter implements Filter {
private static DefaultTokenFilter defaultTokenFilter;
private String clientId = null;
private static List<OBIdentityFilterValidator> validators = new ArrayList<>();
private boolean isTransportCertMandatory;

private static final String BASIC_AUTH_ERROR_MSG = "Unable to find client id in the request. " +
"Invalid Authorization header found.";
Expand All @@ -73,6 +74,15 @@ public class TokenFilter implements Filter {
public void init(FilterConfig filterConfig) {

ServletContext context = filterConfig.getServletContext();

String isTransportCertMandatoryConf = filterConfig.getInitParameter("isTransportCertificateMandatory");
if (isTransportCertMandatoryConf == null) {
// By default, mandating the transport certificate
isTransportCertMandatory = true;
} else {
isTransportCertMandatory = Boolean.parseBoolean(isTransportCertMandatoryConf);
}

context.log("TokenFilter initialized");
}

Expand Down Expand Up @@ -130,6 +140,11 @@ private ServletRequest appendTransportHeader(ServletRequest request, ServletResp
ServletException, IOException, CertificateEncodingException {

if (request instanceof HttpServletRequest) {

if (!isTransportCertMandatory) {
return request;
}

Object certAttribute = request.getAttribute(IdentityCommonConstants.JAVAX_SERVLET_REQUEST_CERTIFICATE);
String x509Certificate = ((HttpServletRequest) request).getHeader(IdentityCommonUtil.getMTLSAuthHeader());
if (new IdentityCommonHelper().isTransportCertAsHeaderEnabled() && x509Certificate != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.lang.reflect.Field;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Base64;
Expand Down Expand Up @@ -106,7 +107,8 @@ public void nonRegulatoryAppWithAuthorizationHeaderTest() throws Exception {
}

@Test(description = "Test the certificate in context/header is mandated")
public void noCertificateTest() throws IOException, OpenBankingException, ServletException {
public void noCertificateTest() throws IOException, OpenBankingException, ServletException, NoSuchFieldException,
IllegalAccessException {

Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
Expand All @@ -120,6 +122,11 @@ public void noCertificateTest() throws IOException, OpenBankingException, Servle
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader())
.thenReturn(IdentityCommonConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand All @@ -129,7 +136,8 @@ public void noCertificateTest() throws IOException, OpenBankingException, Servle
}

@Test(description = "Test the certificate in attribute is present if config is disabled")
public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBankingException, ServletException {
public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {

Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
Expand All @@ -144,6 +152,11 @@ public void certificateIsNotPresentInAttributeTest() throws IOException, OpenBan
Mockito.doReturn(new DefaultTokenFilter()).when(filter).getDefaultTokenFilter();
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
* Test constants.
*/
public class TestConstants {
public static final String IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME = "isTransportCertMandatory";
public static final String TARGET_STREAM = "targetStream";
public static final String CERTIFICATE_HEADER = "x-wso2-mutual-auth-cert";
public static final String EXPIRED_CERTIFICATE_CONTENT = "-----BEGIN CERTIFICATE-----" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.testng.annotations.Test;

import java.io.IOException;
import java.lang.reflect.Field;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.HashMap;
Expand Down Expand Up @@ -123,12 +124,17 @@ public void certificateAttributeValidation() throws Exception {
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);
PowerMockito.when(IdentityCommonUtil.getCertificateFromAttribute(cert)).thenReturn(cert);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
assertEquals(response.getStatus(), HttpServletResponse.SC_OK);
}

@Test(description = "Test whether the certificate header is present")
public void noCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException {
public void noCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
configMap.put(IdentityCommonConstants.ENABLE_TRANSPORT_CERT_AS_HEADER, true);
Expand All @@ -148,6 +154,10 @@ public void noCertificateHeaderValidation() throws IOException, OpenBankingExcep
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand All @@ -159,7 +169,8 @@ public void noCertificateHeaderValidation() throws IOException, OpenBankingExcep


@Test(description = "Test the certificate in attribute is passed as a header")
public void certificateIsPresentInAttributeTest() throws IOException, OpenBankingException, ServletException {
public void certificateIsPresentInAttributeTest() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
MTLSEnforcementValidator mtlsEnforcementValidator = Mockito.spy(MTLSEnforcementValidator.class);
PowerMockito.mockStatic(IdentityCommonUtil.class);

Expand All @@ -183,12 +194,17 @@ public void certificateIsPresentInAttributeTest() throws IOException, OpenBankin
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);
PowerMockito.when(IdentityCommonUtil.getCertificateFromAttribute(cert)).thenReturn(cert);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
assertEquals(response.getStatus(), HttpServletResponse.SC_OK);
}

@Test(description = "Test whether the certificate attribute is valid")
public void invalidCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException {
public void invalidCertificateHeaderValidation() throws IOException, OpenBankingException, ServletException,
NoSuchFieldException, IllegalAccessException {
Map<String, Object> configMap = new HashMap<>();
PowerMockito.mockStatic(IdentityCommonUtil.class);
configMap.put(IdentityCommonConstants.ENABLE_TRANSPORT_CERT_AS_HEADER, true);
Expand All @@ -209,6 +225,10 @@ public void invalidCertificateHeaderValidation() throws IOException, OpenBanking
PowerMockito.when(IdentityCommonUtil.getRegulatoryFromSPMetaData("test")).thenReturn(true);
PowerMockito.when(IdentityCommonUtil.getMTLSAuthHeader()).thenReturn(TestConstants.CERTIFICATE_HEADER);

Field privateField = TokenFilter.class.getDeclaredField(TestConstants.IS_TRANSPORT_CERT_MANDATORY_FIELD_NAME);
privateField.setAccessible(true);
privateField.set(filter, true);

filter.doFilter(request, response, filterChain);
Map<String, String> responseMap = TestUtil.getResponse(response.getOutputStream());
assertEquals(response.getStatus(), HttpStatus.SC_BAD_REQUEST);
Expand Down
Loading