diff --git a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java index 846573289f..16866734e8 100644 --- a/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java +++ b/src/main/java/com/amazon/dlic/auth/http/saml/HTTPSamlAuthenticator.java @@ -184,7 +184,9 @@ public Optional reRequestAuthentication(final SecurityRequest if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { // Verficiation of SAML ASC endpoint only works with RestRequests if (!(request instanceof OpenSearchRequest)) { - throw new SecurityRequestChannelUnsupported(); + throw new SecurityRequestChannelUnsupported( + API_AUTHTOKEN_SUFFIX + " not supported for request of type " + request.getClass().getName() + ); } else { final OpenSearchRequest openSearchRequest = (OpenSearchRequest) request; final RestRequest restRequest = openSearchRequest.breakEncapsulationForRequest(); @@ -200,6 +202,9 @@ public Optional reRequestAuthentication(final SecurityRequest new SecurityResponse(HttpStatus.SC_UNAUTHORIZED, Map.of("WWW-Authenticate", getWwwAuthenticateHeader(saml2Settings)), "") ); } catch (Exception e) { + if (e instanceof SecurityRequestChannelUnsupported) { + throw (SecurityRequestChannelUnsupported) e; + } log.error("Error in reRequestAuthentication()", e); return Optional.empty(); } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequest.java b/src/main/java/org/opensearch/security/filter/SecurityRequest.java index 7e6e94e0a6..ab6f41b354 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequest.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequest.java @@ -25,29 +25,29 @@ public interface SecurityRequest { /** Collection of headers associated with the request */ - public Map> getHeaders(); + Map> getHeaders(); /** The SSLEngine associated with the request */ - public SSLEngine getSSLEngine(); + SSLEngine getSSLEngine(); /** The path of the request */ - public String path(); + String path(); /** The method type of this request */ - public Method method(); + Method method(); /** The remote address of the request, possible null */ - public Optional getRemoteAddress(); + Optional getRemoteAddress(); /** The full uri of the request */ - public String uri(); + String uri(); /** Finds the first value of the matching header or null */ - default public String header(final String headerName) { + default String header(final String headerName) { final Optional>> headersMap = Optional.ofNullable(getHeaders()); return headersMap.map(headers -> headers.get(headerName)).map(List::stream).flatMap(Stream::findFirst).orElse(null); } /** The parameters associated with this request */ - public Map params(); + Map params(); } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java index bcacc2cf7a..fe9c557825 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRequestChannelUnsupported.java @@ -11,7 +11,12 @@ package org.opensearch.security.filter; +import org.opensearch.OpenSearchException; + /** Thrown when a security rest channel is not supported */ -public class SecurityRequestChannelUnsupported extends RuntimeException { +public class SecurityRequestChannelUnsupported extends OpenSearchException { + public SecurityRequestChannelUnsupported(String msg, Object... args) { + super(msg, args); + } } diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java index 263d7500cf..c492656bca 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestFilter.java @@ -30,7 +30,6 @@ import java.util.List; import java.util.Optional; import java.util.Set; -import java.util.regex.Matcher; import java.util.regex.Pattern; import javax.net.ssl.SSLPeerUnverifiedException; @@ -48,7 +47,6 @@ import org.opensearch.rest.RestChannel; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.RestRequest.Method; import org.opensearch.security.auditlog.AuditLog; import org.opensearch.security.auditlog.AuditLog.Origin; import org.opensearch.security.auth.BackendRegistry; @@ -299,9 +297,7 @@ public void checkAndAuthenticateRequest(SecurityRequestChannel requestChannel) t return; } - Matcher matcher = PATTERN_PATH_PREFIX.matcher(requestChannel.path()); - final String suffix = matcher.matches() ? matcher.group(2) : null; - if (requestChannel.method() != Method.OPTIONS && !(HEALTH_SUFFIX.equals(suffix)) && !(WHO_AM_I_SUFFIX.equals(suffix))) { + if (!SecurityRestUtils.shouldSkipAuthentication(requestChannel)) { if (!registry.authenticate(requestChannel)) { // another roundtrip org.apache.logging.log4j.ThreadContext.remove("user"); diff --git a/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java index 1599346b90..705fe31ee0 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java +++ b/src/main/java/org/opensearch/security/filter/SecurityRestUtils.java @@ -1,5 +1,24 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + * + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + package org.opensearch.security.filter; +import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX; +import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX; +import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX; + +import java.util.regex.Matcher; + +import org.opensearch.rest.RestRequest.Method; + public class SecurityRestUtils { public static String path(final String uri) { final int index = uri.indexOf('?'); @@ -9,4 +28,15 @@ public static String path(final String uri) { return uri; } } + + public static boolean shouldSkipAuthentication(SecurityRequestChannel request) { + Matcher matcher = PATTERN_PATH_PREFIX.matcher(request.path()); + final String suffix = matcher.matches() ? matcher.group(2) : null; + + boolean shouldSkipAuthentication = (request.method() == Method.OPTIONS) + || HEALTH_SUFFIX.equals(suffix) + || WHO_AM_I_SUFFIX.equals(suffix); + + return shouldSkipAuthentication; + } } diff --git a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java index 5112ceced3..51825e977b 100644 --- a/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java +++ b/src/main/java/org/opensearch/security/ssl/http/netty/Netty4HttpRequestHeaderVerifier.java @@ -10,7 +10,6 @@ import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultHttpRequest; -import io.netty.handler.codec.http.HttpMethod; import io.netty.handler.codec.http.HttpRequest; import io.netty.util.ReferenceCountUtil; import org.opensearch.ExceptionsHelper; @@ -20,7 +19,6 @@ import io.netty.channel.ChannelHandlerContext; import org.opensearch.http.netty4.Netty4HttpChannel; import org.opensearch.http.netty4.Netty4HttpServerTransport; -import org.opensearch.rest.RestUtils; import org.opensearch.security.filter.SecurityRequestChannel; import org.opensearch.security.filter.SecurityRequestChannelUnsupported; import org.opensearch.security.filter.SecurityRequestFactory; @@ -34,12 +32,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.OpenSearchSecurityException; -import java.util.regex.Matcher; - -import static com.amazon.dlic.auth.http.saml.HTTPSamlAuthenticator.API_AUTHTOKEN_SUFFIX; -import static org.opensearch.security.filter.SecurityRestFilter.HEALTH_SUFFIX; -import static org.opensearch.security.filter.SecurityRestFilter.PATTERN_PATH_PREFIX; -import static org.opensearch.security.filter.SecurityRestFilter.WHO_AM_I_SUFFIX; import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE; import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE; import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS; @@ -83,34 +75,21 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro ctx.channel().attr(IS_AUTHENTICATED).set(Boolean.FALSE); final Netty4HttpChannel httpChannel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); - String rawPath = SecurityRestUtils.path(msg.uri()); - String path = RestUtils.decodeComponent(rawPath); - Matcher matcher = PATTERN_PATH_PREFIX.matcher(path); - final String suffix = matcher.matches() ? matcher.group(2) : null; - if (API_AUTHTOKEN_SUFFIX.equals(suffix)) { - ctx.fireChannelRead(msg); - return; - } final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(msg, httpChannel); ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { injectUser(msg, threadContext); - boolean shouldSkipAuthentication = HttpMethod.OPTIONS.equals(msg.method()) - || HEALTH_SUFFIX.equals(suffix) - || WHO_AM_I_SUFFIX.equals(suffix); - - if (!shouldSkipAuthentication) { - // If request channel is completed and a response is sent, then there was a failure during authentication - restFilter.checkAndAuthenticateRequest(requestChannel); - } + // If request channel is completed and a response is sent, then there was a failure during authentication + restFilter.checkAndAuthenticateRequest(requestChannel); ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false); ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore); requestChannel.getQueuedResponse().ifPresent(response -> ctx.channel().attr(EARLY_RESPONSE).set(response)); + boolean shouldSkipAuthentication = SecurityRestUtils.shouldSkipAuthentication(requestChannel); boolean shouldDecompress = !shouldSkipAuthentication && requestChannel.getQueuedResponse().isEmpty(); if (requestChannel.getQueuedResponse().isEmpty() || shouldSkipAuthentication) { diff --git a/src/test/java/org/opensearch/security/filter/SecurityRestUtilsTests.java b/src/test/java/org/opensearch/security/filter/SecurityRestUtilsTests.java new file mode 100644 index 0000000000..0424d780ef --- /dev/null +++ b/src/test/java/org/opensearch/security/filter/SecurityRestUtilsTests.java @@ -0,0 +1,63 @@ +package org.opensearch.security.filter; + +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.Test; +import org.opensearch.http.netty4.Netty4HttpChannel; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; + +public class SecurityRestUtilsTests { + + @Test + public void testShouldSkipAuthentication_positive() { + FullHttpRequest request1 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/"); + NettyRequestChannel requestChannel1 = new NettyRequestChannel(request1, mock(Netty4HttpChannel.class)); + + assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel1)); + + FullHttpRequest request2 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_plugins/_security/health"); + NettyRequestChannel requestChannel2 = new NettyRequestChannel(request2, mock(Netty4HttpChannel.class)); + + assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel2)); + + FullHttpRequest request3 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_plugins/_security/whoami"); + NettyRequestChannel requestChannel3 = new NettyRequestChannel(request3, mock(Netty4HttpChannel.class)); + + assertTrue(SecurityRestUtils.shouldSkipAuthentication(requestChannel3)); + } + + @Test + public void testShouldSkipAuthentication_negative() { + FullHttpRequest request1 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + NettyRequestChannel requestChannel1 = new NettyRequestChannel(request1, mock(Netty4HttpChannel.class)); + + assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel1)); + + FullHttpRequest request2 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/_cluster/health"); + NettyRequestChannel requestChannel2 = new NettyRequestChannel(request2, mock(Netty4HttpChannel.class)); + + assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel2)); + + FullHttpRequest request3 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/my-index/_search"); + NettyRequestChannel requestChannel3 = new NettyRequestChannel(request3, mock(Netty4HttpChannel.class)); + + assertFalse(SecurityRestUtils.shouldSkipAuthentication(requestChannel3)); + } + + @Test + public void testGetRawPath() { + String rawPathWithParams = "/_cluster/health?pretty"; + String rawPathWithoutParams = "/my-index/search"; + + String path1 = SecurityRestUtils.path(rawPathWithParams); + String path2 = SecurityRestUtils.path(rawPathWithoutParams); + + assertTrue("/_cluster/health".equals(path1)); + assertTrue("/my-index/search".equals(path2)); + } +}