Skip to content

Commit

Permalink
Apply method predicate before searching type hierarchy
Browse files Browse the repository at this point in the history
This is a proof-of-concept fix for junit-team#3498.
  • Loading branch information
sbrannen committed Oct 9, 2023
1 parent f9ae7f6 commit 82e2852
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 15 deletions.
28 changes: 28 additions & 0 deletions junit-jupiter-engine/src/test/java/a/A.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package a;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.junit.jupiter.api.BeforeAll;

public abstract class A {

public static final List<String> invocations = Collections.synchronizedList(new ArrayList<>());

@BeforeAll
static void before() {
invocations.add("A.before()");
}

}
38 changes: 38 additions & 0 deletions junit-jupiter-engine/src/test/java/b/B.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright 2015-2023 the original author or authors.
*
* All rights reserved. This program and the accompanying materials are
* made available under the terms of the Eclipse Public License v2.0 which
* accompanies this distribution and is available at
*
* https://www.eclipse.org/legal/epl-v20.html
*/

package b;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import a.A;

public class B extends A {

@BeforeEach
void before() {
invocations.add("B.before()");
}

@Test
void test() {
invocations.add("B.test()");
}

@AfterAll
static void checkInvocations() {
assertThat(A.invocations).containsExactly("A.before()", "B.before()", "B.test()");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1489,29 +1489,27 @@ public static Stream<Method> streamMethods(Class<?> clazz, Predicate<Method> pre
Preconditions.notNull(predicate, "Predicate must not be null");
Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null");

// @formatter:off
return findAllMethodsInHierarchy(clazz, traversalMode).stream()
.filter(predicate)
.distinct();
// @formatter:on
return findAllMethodsInHierarchy(clazz, predicate, traversalMode).stream().distinct();
}

/**
* Find all non-synthetic methods in the superclass and interface hierarchy,
* excluding Object.
* excluding Object, that match the specified {@code predicate}.
*/
private static List<Method> findAllMethodsInHierarchy(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Method> findAllMethodsInHierarchy(Class<?> clazz, Predicate<Method> predicate,
HierarchyTraversalMode traversalMode) {

Preconditions.notNull(clazz, "Class must not be null");
Preconditions.notNull(traversalMode, "HierarchyTraversalMode must not be null");

// @formatter:off
List<Method> localMethods = getDeclaredMethods(clazz, traversalMode).stream()
.filter(method -> !method.isSynthetic())
.filter(predicate.and(method -> !method.isSynthetic()))
.collect(toList());
List<Method> superclassMethods = getSuperclassMethods(clazz, traversalMode).stream()
List<Method> superclassMethods = getSuperclassMethods(clazz, predicate, traversalMode).stream()
.filter(method -> !isMethodShadowedByLocalMethods(method, localMethods))
.collect(toList());
List<Method> interfaceMethods = getInterfaceMethods(clazz, traversalMode).stream()
List<Method> interfaceMethods = getInterfaceMethods(clazz, predicate, traversalMode).stream()
.filter(method -> !isMethodShadowedByLocalMethods(method, localMethods))
.collect(toList());
// @formatter:on
Expand Down Expand Up @@ -1647,16 +1645,18 @@ private static int defaultMethodSorter(Method method1, Method method2) {
return comparison;
}

private static List<Method> getInterfaceMethods(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Method> getInterfaceMethods(Class<?> clazz, Predicate<Method> predicate,
HierarchyTraversalMode traversalMode) {

List<Method> allInterfaceMethods = new ArrayList<>();
for (Class<?> ifc : clazz.getInterfaces()) {

// @formatter:off
List<Method> localInterfaceMethods = getMethods(ifc).stream()
.filter(m -> !isAbstract(m))
.filter(predicate.and(method -> !isAbstract(method)))
.collect(toList());

List<Method> superinterfaceMethods = getInterfaceMethods(ifc, traversalMode).stream()
List<Method> superinterfaceMethods = getInterfaceMethods(ifc, predicate, traversalMode).stream()
.filter(method -> !isMethodShadowedByLocalMethods(method, localInterfaceMethods))
.collect(toList());
// @formatter:on
Expand Down Expand Up @@ -1706,12 +1706,14 @@ private static boolean isFieldShadowedByLocalFields(Field field, List<Field> loc
return localFields.stream().anyMatch(local -> local.getName().equals(field.getName()));
}

private static List<Method> getSuperclassMethods(Class<?> clazz, HierarchyTraversalMode traversalMode) {
private static List<Method> getSuperclassMethods(Class<?> clazz, Predicate<Method> predicate,
HierarchyTraversalMode traversalMode) {

Class<?> superclass = clazz.getSuperclass();
if (!isSearchable(superclass)) {
return Collections.emptyList();
}
return findAllMethodsInHierarchy(superclass, traversalMode);
return findAllMethodsInHierarchy(superclass, predicate, traversalMode);
}

private static boolean isMethodShadowedByLocalMethods(Method method, List<Method> localMethods) {
Expand Down

0 comments on commit 82e2852

Please sign in to comment.