diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java index e184e6a28d..f4a31feec3 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java @@ -31,7 +31,6 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ import org.springframework.test.context.TestContext; import org.springframework.test.context.TestExecutionListener; import org.springframework.test.context.support.AbstractTestExecutionListener; -import org.springframework.test.util.MetaAnnotationUtils; import org.springframework.test.web.servlet.MockMvc; /** @@ -61,10 +60,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut */ @Override public void beforeTestMethod(TestContext testContext) { - TestSecurityContext testSecurityContext = createTestSecurityContext(testContext.getTestMethod(), testContext); - if (testSecurityContext == null) { - testSecurityContext = createTestSecurityContext(testContext.getTestClass(), testContext); - } + TestSecurityContext testSecurityContext = findTestSecurityContext(testContext); if (testSecurityContext == null) { return; } @@ -77,6 +73,21 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut } } + private TestSecurityContext findTestSecurityContext(TestContext testContext) { + TestSecurityContext testSecurityContext = createTestSecurityContext(testContext.getTestMethod(), testContext); + if (testSecurityContext != null) { + return testSecurityContext; + } + for (Class classToSearch = testContext.getTestClass(); classToSearch != null; classToSearch = classToSearch + .getEnclosingClass()) { + testSecurityContext = createTestSecurityContext(classToSearch, testContext); + if (testSecurityContext != null) { + return testSecurityContext; + } + } + return null; + } + /** * If configured before test execution sets the SecurityContext * @since 5.1 @@ -97,10 +108,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut } private TestSecurityContext createTestSecurityContext(Class annotated, TestContext context) { - MetaAnnotationUtils.AnnotationDescriptor withSecurityContextDescriptor = MetaAnnotationUtils - .findAnnotationDescriptor(annotated, WithSecurityContext.class); - WithSecurityContext withSecurityContext = (withSecurityContextDescriptor != null) - ? withSecurityContextDescriptor.getAnnotation() : null; + WithSecurityContext withSecurityContext = AnnotationUtils.findAnnotation(annotated, WithSecurityContext.class); return createTestSecurityContext(annotated, withSecurityContext, context); } diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java index e4849015f3..e29303ffbe 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExcecutionListenerTests.java @@ -92,6 +92,30 @@ public class WithSecurityContextTestExcecutionListenerTests { assertThat(TestSecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("user"); } + @Test + @SuppressWarnings({ "rawtypes", "unchecked" }) + public void beforeTestMethodInnerClass() throws Exception { + Class testClass = OuterClass.InnerClass.class; + Method testNoAnnotation = ReflectionUtils.findMethod(testClass, "testNoAnnotation"); + given(this.testContext.getTestClass()).willReturn(testClass); + given(this.testContext.getTestMethod()).willReturn(testNoAnnotation); + given(this.testContext.getApplicationContext()).willThrow(new IllegalStateException("")); + this.listener.beforeTestMethod(this.testContext); + assertThat(TestSecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("user"); + } + + @Test + @SuppressWarnings({ "rawtypes", "unchecked" }) + public void beforeTestMethodInnerInnerClass() throws Exception { + Class testClass = OuterClass.InnerClass.InnerInnerClass.class; + Method testNoAnnotation = ReflectionUtils.findMethod(testClass, "testNoAnnotation"); + given(this.testContext.getTestClass()).willReturn(testClass); + given(this.testContext.getTestMethod()).willReturn(testNoAnnotation); + given(this.testContext.getApplicationContext()).willThrow(new IllegalStateException("")); + this.listener.beforeTestMethod(this.testContext); + assertThat(TestSecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("user"); + } + // gh-3962 @Test public void withSecurityContextAfterSqlScripts() { @@ -166,4 +190,23 @@ public class WithSecurityContextTestExcecutionListenerTests { } + @WithMockUser + static class OuterClass { + + static class InnerClass { + + void testNoAnnotation() { + } + + static class InnerInnerClass { + + void testNoAnnotation() { + } + + } + + } + + } + }