Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Signed-off-by: asjervanasten <[email protected]>
  • Loading branch information
appiepollo14 committed Dec 29, 2023
1 parent 9fd53f9 commit 9ea8a5f
Show file tree
Hide file tree
Showing 10 changed files with 229 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Comparator;

import jakarta.ws.rs.ApplicationPath;
import jakarta.ws.rs.core.Application;
import jakarta.ws.rs.ext.MessageBodyReader;
import jakarta.ws.rs.ext.MessageBodyWriter;

import org.apache.cxf.jaxrs.client.JAXRSClientFactoryBean;
import org.junit.platform.commons.support.AnnotationSupport;
import org.junit.platform.commons.support.ReflectionSupport;
Expand All @@ -55,8 +55,8 @@ public class RestClientBuilder {

/**
* @param appContextRoot The protocol, hostname, port, and application root path for the REST Client
* For example, <code>http://localhost:8080/myapp/</code>. If unspecified, the app context
* root will be automatically detected by {@link ApplicationEnvironment#getApplicationURL()}
* For example, <code>http://localhost:8080/myapp/</code>. If unspecified, the app context
* root will be automatically detected by {@link ApplicationEnvironment#getApplicationURL()}
* @return The same builder instance
*/
public RestClientBuilder withAppContextRoot(String appContextRoot) {
Expand All @@ -67,9 +67,9 @@ public RestClientBuilder withAppContextRoot(String appContextRoot) {

/**
* @param jaxrsPath The portion of the path after the app context root. For example, if a JAX-RS
* endpoint is deployed at <code>http://localhost:8080/myapp/hello</code> and the app context root
* is <code>http://localhost:8080/myapp/</code>, then the jaxrsPath is <code>hello</code>. If
* unspecified, the JAX-RS path will be automatically detected by annotation scanning.
* endpoint is deployed at <code>http://localhost:8080/myapp/hello</code> and the app context root
* is <code>http://localhost:8080/myapp/</code>, then the jaxrsPath is <code>hello</code>. If
* unspecified, the JAX-RS path will be automatically detected by annotation scanning.
* @return The same builder instance
*/
public RestClientBuilder withJaxrsPath(String jaxrsPath) {
Expand All @@ -93,7 +93,7 @@ public RestClientBuilder withJwt(String jwt) {
}

/**
* @param user The username portion of the Basic auth header
* @param user The username portion of the Basic auth header
* @param password The password portion of the Basic auth header
* @return The same builder instance
*/
Expand All @@ -110,7 +110,7 @@ public RestClientBuilder withBasicAuth(String user, String password) {
}

/**
* @param key The header key
* @param key The header key
* @param value The header value
* @return The same builder instance
*/
Expand All @@ -126,8 +126,8 @@ public RestClientBuilder withHeader(String key, String value) {

/**
* @param providers One or more providers to apply. Providers typically implement
* {@link MessageBodyReader} and/or {@link MessageBodyWriter}. If unspecified,
* the {@link JsonBProvider} will be applied.
* {@link MessageBodyReader} and/or {@link MessageBodyWriter}. If unspecified,
* the {@link JsonBProvider} will be applied.
* @return The same builder instance
*/
public RestClientBuilder withProviders(Class<?>... providers) {
Expand All @@ -145,7 +145,7 @@ public <T> T build(Class<T> clazz) {
providers = Collections.singletonList(JsonBProvider.class);

JAXRSClientFactoryBean bean = new org.apache.cxf.jaxrs.client.JAXRSClientFactoryBean();
String basePath = join(appContextRoot, jaxrsPath);
String basePath = joinAppAndJaxrsPath(appContextRoot, jaxrsPath);
LOG.info("Building rest client for " + clazz + " with base path: " + basePath + " and providers: " + providers);
bean.setResourceClass(clazz);
bean.setProviders(providers);
Expand All @@ -163,10 +163,10 @@ private static String locateApplicationPath(Class<?> clazz) {

// First check for a jakarta.ws.rs.core.Application in the same package as the resource
List<Class<?>> appClasses = ReflectionSupport.findAllClassesInPackage(resourcePackage,
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
if (appClasses.size() == 0) {
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
if (appClasses.isEmpty()) {
LOG.debug("no classes implementing Application found in pkg: " + resourcePackage);
// If not found, check under the 3rd package, so com.foo.bar.*
// Classpath scanning can be expensive, so we jump straight to the 3rd package from root instead
Expand All @@ -176,33 +176,33 @@ private static String locateApplicationPath(Class<?> clazz) {
String checkPkg = pkgs[0] + '.' + pkgs[1] + '.' + pkgs[2];
LOG.debug("checking in pkg: " + checkPkg);
appClasses = ReflectionSupport.findAllClassesInPackage(checkPkg,
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
c -> Application.class.isAssignableFrom(c) &&
AnnotationSupport.isAnnotated(c, ApplicationPath.class),
n -> true);
}
}

if (appClasses.size() == 0) {
if (appClasses.isEmpty()) {
LOG.info("No classes implementing 'jakarta.ws.rs.core.Application' found on classpath to set base path from " + clazz +
". Defaulting base path to '/'");
". Defaulting base path to '/'");
return "";
}

Class<?> selectedClass = appClasses.stream()
.sorted((c1, c2) -> c1.getName().compareTo(c2.getName()))
.findFirst()
.get();
.sorted(Comparator.comparing(Class::getName))
.findFirst()
.get();
ApplicationPath appPath = AnnotationSupport.findAnnotation(selectedClass, ApplicationPath.class).get();
if (appClasses.size() > 1) {
LOG.warn("Found multiple classes implementing 'jakarta.ws.rs.core.Application' on classpath: " + appClasses +
". Setting base path from the first class discovered (" + selectedClass.getCanonicalName() + ") with path: " +
appPath.value());
". Setting base path from the first class discovered (" + selectedClass.getCanonicalName() + ") with path: " +
appPath.value());
}
LOG.debug("Using base ApplicationPath of '" + appPath.value() + "'");
return appPath.value();
}

private static String join(String firstPart, String secondPart) {
private static String joinAppAndJaxrsPath(String firstPart, String secondPart) {
if (firstPart.endsWith("/") && secondPart.startsWith("/"))
return firstPart + secondPart.substring(1);
else if (firstPart.endsWith("/") || secondPart.startsWith("/"))
Expand All @@ -211,4 +211,5 @@ else if (firstPart.endsWith("/") || secondPart.startsWith("/"))
return firstPart + "/" + secondPart;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@
*/
package org.microshed.testing.jupiter;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Properties;

import org.junit.jupiter.api.extension.BeforeAllCallback;
import org.junit.jupiter.api.extension.ExtensionConfigurationException;
import org.junit.jupiter.api.extension.ExtensionContext;
Expand All @@ -44,6 +33,12 @@
import org.microshed.testing.kafka.KafkaConsumerClient;
import org.microshed.testing.kafka.KafkaProducerClient;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.util.*;

/**
* JUnit Jupiter extension that is applied whenever the <code>@MicroProfileTest</code> is used on a test class.
* Currently this is tied to Testcontainers managing runtime build/deployment, but in a future version
Expand Down Expand Up @@ -90,8 +85,8 @@ private static void injectRestClients(Class<?> clazz) {

for (Field restClientField : restClientFields) {
if (!Modifier.isPublic(restClientField.getModifiers()) ||
!Modifier.isStatic(restClientField.getModifiers()) ||
Modifier.isFinal(restClientField.getModifiers())) {
!Modifier.isStatic(restClientField.getModifiers()) ||
Modifier.isFinal(restClientField.getModifiers())) {
throw new ExtensionConfigurationException("REST client field must be public, static, and non-final: " + restClientField);
}
RestClientBuilder rcBuilder = new RestClientBuilder();
Expand Down Expand Up @@ -137,10 +132,10 @@ private static void injectKafkaClients(Class<?> clazz) {
throw new ExtensionConfigurationException("Fields annotated with @KafkaProducerClient must be of the type " + KafkaProducer.getName());
}
if (!Modifier.isPublic(producerField.getModifiers()) ||
!Modifier.isStatic(producerField.getModifiers()) ||
Modifier.isFinal(producerField.getModifiers())) {
!Modifier.isStatic(producerField.getModifiers()) ||
Modifier.isFinal(producerField.getModifiers())) {
throw new ExtensionConfigurationException("The KafkaProducer field annotated with @KafkaProducerClient " +
"must be public, static, and non-final: " + producerField);
"must be public, static, and non-final: " + producerField);
}

Properties properties = kafkaProcessor.getProducerProperties(producerField);
Expand All @@ -159,10 +154,10 @@ private static void injectKafkaClients(Class<?> clazz) {
throw new ExtensionConfigurationException("Fields annotated with @KafkaConsumerClient must be of the type " + KafkaConsumer.getName());
}
if (!Modifier.isPublic(consumerField.getModifiers()) ||
!Modifier.isStatic(consumerField.getModifiers()) ||
Modifier.isFinal(consumerField.getModifiers())) {
!Modifier.isStatic(consumerField.getModifiers()) ||
Modifier.isFinal(consumerField.getModifiers())) {
throw new ExtensionConfigurationException("The KafkaProducer field annotated with @KafkaConsumerClient " +
"must be public, static, and non-final: " + consumerField);
"must be public, static, and non-final: " + consumerField);
}

Properties properties = kafkaProcessor.getConsumerProperties(consumerField);
Expand All @@ -182,7 +177,7 @@ private static void injectKafkaClients(Class<?> clazz) {
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
@SuppressWarnings({"unchecked", "rawtypes"})
private static void configureRestAssured(ApplicationEnvironment config) {
if (!config.configureRestAssured())
return;
Expand Down
10 changes: 6 additions & 4 deletions core/src/main/java/org/microshed/testing/jwt/JwtConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@
*/
package org.microshed.testing.jwt;

import org.junit.jupiter.api.extension.ExtendWith;
import org.microshed.testing.jaxrs.RESTClient;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.microshed.testing.jaxrs.RESTClient;

/**
* Used to annotate a REST Client to configure MicroProfile JWT settings
* that will be applied to all of its HTTP invocations.
* In order for this annotation to have any effect, the field must also
* be annotated with {@link RESTClient}.
*/
@Target({ ElementType.FIELD })
@Target({ElementType.FIELD, ElementType.METHOD})
@Retention(RetentionPolicy.RUNTIME)
@ExtendWith(JwtConfigExtension.class)
public @interface JwtConfig {

public static final String DEFAULT_ISSUER = "http://testissuer.com";
Expand All @@ -46,7 +48,7 @@
* array of claims in the following format:
* key=value
* example: {"sub=fred", "upn=fred", "kid=123"}
*
* <p>
* For arrays, separate values with a comma.
* example: {"groups=red,green,admin", "sub=fred"}
*
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package org.microshed.testing.jwt;

import org.junit.jupiter.api.extension.*;
import org.microshed.testing.internal.InternalLogger;
import org.microshed.testing.jupiter.MicroShedTestExtension;

import java.lang.reflect.Field;
import java.lang.reflect.Method;

public class JwtConfigExtension implements BeforeTestExecutionCallback, AfterTestExecutionCallback {

private static final InternalLogger LOG = InternalLogger.get(JwtConfigExtension.class);

@Override
public void beforeTestExecution(ExtensionContext context) throws Exception {
configureJwt(context);
}

@Override
public void afterTestExecution(ExtensionContext context) throws Exception {
removeJwt(context);
}

private void configureJwt(ExtensionContext context) throws Exception {

// Check if the test method has the @JwtConfig annotation
Method testMethod = context.getTestMethod().orElse(null);
if (testMethod != null) {

// Check if RestAssured is being used
Class<?> restAssuredClass = tryLoad("io.restassured.RestAssured");
if (restAssuredClass == null) {
LOG.debug("RESTAssured not found!");
return;
}

LOG.debug("RESTAssured found!");

JwtConfig jwtConfig = testMethod.getAnnotation(JwtConfig.class);
if (jwtConfig != null) {
// Configure RestAssured with the values from @JwtConfig for each test method
LOG.info("JWTConfig on method: " + testMethod.getName());
// Get the RequestSpecBuilder class
Class<?> requestSpecBuilderClass = Class.forName("io.restassured.builder.RequestSpecBuilder");
// Create an instance of RequestSpecBuilder
Object requestSpecBuilder = requestSpecBuilderClass.newInstance();
// Get the requestSpecification field
Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification");
requestSpecificationField.setAccessible(true);

// Get the header method of RequestSpecBuilder
Method headerMethod = requestSpecBuilderClass.getDeclaredMethod("addHeader", String.class, String.class);

try {
String jwt = JwtBuilder.buildJwt(jwtConfig.subject(), jwtConfig.issuer(), jwtConfig.claims());
headerMethod.invoke(requestSpecBuilder, "Authorization", "Bearer " + jwt);
LOG.debug("Using provided JWT auth header: " + jwt);
} catch (Exception e) {
throw new ExtensionConfigurationException("Error while building JWT for method " + testMethod.getName() + " with JwtConfig: " + jwtConfig, e);
}

// Set the updated requestSpecification
requestSpecificationField.set(null, requestSpecBuilderClass.getMethod("build").invoke(requestSpecBuilder));
}
}
}

private void removeJwt(ExtensionContext context) throws Exception {
// Check if RestAssured is being used
Class<?> restAssuredClass = tryLoad("io.restassured.RestAssured");
if (restAssuredClass == null) {
LOG.debug("RESTAssured not found!");
return;
}

// Check if the test method has the @JwtConfig annotation
Method testMethod = context.getTestMethod().orElse(null);
if (testMethod != null) {

LOG.debug("Method was annotated with: " + testMethod.getName());
// Get the requestSpecification field
Field requestSpecificationField = restAssuredClass.getDeclaredField("requestSpecification");
requestSpecificationField.setAccessible(true);

// Removes all requestSpec
requestSpecificationField.set(null, null);
}
}

private static Class<?> tryLoad(String clazz) {
try {
return Class.forName(clazz, false, MicroShedTestExtension.class.getClassLoader());
} catch (ClassNotFoundException | LinkageError e) {
return null;
}
}
}
Loading

0 comments on commit 9ea8a5f

Please sign in to comment.