Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WebSocket NEXT: automatically close connection when OIDC extension provides SecurityIdentity and token expires #40857

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/main/asciidoc/websockets-next-reference.adoc
Original file line number Diff line number Diff line change
Expand Up @@ -639,8 +639,10 @@
quarkus.http.auth.permission.secured.policy=authenticated
----

Other options for securing HTTP upgrade requests, such as using the security annotations, will be explored in the future.

Check warning on line 642 in docs/src/main/asciidoc/websockets-next-reference.adoc

View workflow job for this annotation

GitHub Actions / Linting with Vale

[vale] reported by reviewdog 🐶 [Quarkus.TermsSuggestions] Depending on the context, consider using 'by using' or 'that uses' rather than 'using'. Raw Output: {"message": "[Quarkus.TermsSuggestions] Depending on the context, consider using 'by using' or 'that uses' rather than 'using'.", "location": {"path": "docs/src/main/asciidoc/websockets-next-reference.adoc", "range": {"start": {"line": 642, "column": 58}}}, "severity": "INFO"}

NOTE: When OpenID Connect extension is used and token expires, Quarkus automatically closes connection.

[[websocket-next-configuration-reference]]
== Configuration reference

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
package io.quarkus.websockets.next.test.security;

import static io.quarkus.websockets.next.test.security.SecurityTestBase.basicAuth;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.time.Duration;
import java.util.concurrent.atomic.AtomicReference;

import jakarta.inject.Inject;
import jakarta.inject.Singleton;

import org.awaitility.Awaitility;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.security.Authenticated;
import io.quarkus.security.identity.AuthenticationRequestContext;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.security.identity.SecurityIdentityAugmentor;
import io.quarkus.security.runtime.QuarkusSecurityIdentity;
import io.quarkus.security.test.utils.TestIdentityController;
import io.quarkus.security.test.utils.TestIdentityProvider;
import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.CloseReason;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.smallrye.mutiny.Uni;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;

public class AuthenticationExpiredTest {

@Inject
Vertx vertx;

@TestHTTPResource("end")
URI endUri;

@BeforeAll
public static void setupUsers() {
TestIdentityController.resetRoles()
.add("admin", "admin", "admin")
.add("user", "user", "user");
}

@RegisterExtension
static final QuarkusUnitTest config = new QuarkusUnitTest()
.withApplicationRoot(root -> root.addClasses(Endpoint.class, TestIdentityProvider.class,
TestIdentityController.class, WSClient.class, ExpiredIdentityAugmentor.class, SecurityTestBase.class));

@Test
public void testConnectionClosedWhenAuthExpires() {
try (WSClient client = new WSClient(vertx)) {
client.connect(basicAuth("admin", "admin"), endUri);

long threeSecondsFromNow = Duration.ofMillis(System.currentTimeMillis()).plusSeconds(3).toMillis();
for (int i = 1; true; i++) {
if (client.isClosed()) {
break;
} else if (System.currentTimeMillis() > threeSecondsFromNow) {
Assertions.fail("Authentication expired, therefore connection should had been closed");
}
client.sendAndAwaitReply("Hello #" + i + " from ");
}

var receivedMessages = client.getMessages().stream().map(Buffer::toString).toList();
assertTrue(receivedMessages.size() > 2, receivedMessages.toString());
assertTrue(receivedMessages.contains("Hello #1 from admin"), receivedMessages.toString());
assertTrue(receivedMessages.contains("Hello #2 from admin"), receivedMessages.toString());
assertEquals(1008, client.closeStatusCode(), "Expected close status 1008, but got " + client.closeStatusCode());

Awaitility
.await()
.atMost(Duration.ofSeconds(1))
.untilAsserted(() -> assertTrue(Endpoint.CLOSED_MESSAGE.get()
.startsWith("Connection closed with reason 'Authentication expired'")));
}
}

@Singleton
public static class ExpiredIdentityAugmentor implements SecurityIdentityAugmentor {

@Override
public Uni<SecurityIdentity> augment(SecurityIdentity securityIdentity,
AuthenticationRequestContext authenticationRequestContext) {
return Uni
.createFrom()
.item(QuarkusSecurityIdentity
.builder(securityIdentity)
.addAttribute("quarkus.identity.expire-time", expireIn2Seconds())
.build());
}

private static long expireIn2Seconds() {
return Duration.ofMillis(System.currentTimeMillis())
.plusSeconds(2)
.toSeconds();
}
}

@WebSocket(path = "/end")
public static class Endpoint {

static final AtomicReference<String> CLOSED_MESSAGE = new AtomicReference<>();

@Inject
SecurityIdentity currentIdentity;

@Authenticated
@OnTextMessage
String echo(String message) {
return message + currentIdentity.getPrincipal().getName();
}

@OnClose
void close(CloseReason reason, WebSocketConnection connection) {
CLOSED_MESSAGE.set("Connection closed with reason '%s': %s".formatted(reason.getMessage(), connection));
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ public void handle(Void event) {
handleFailure(unhandledFailureStrategy, r.cause(), "Unable to complete @OnClose callback",
connection);
}
securitySupport.onClose();
onClose.run();
if (timerId != null) {
vertx.cancelTimer(timerId);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,22 +1,36 @@
package io.quarkus.websockets.next.runtime;

import java.util.Objects;
import java.util.concurrent.TimeUnit;

import jakarta.enterprise.inject.Instance;

import org.jboss.logging.Logger;

import io.quarkus.security.identity.CurrentIdentityAssociation;
import io.quarkus.security.identity.SecurityIdentity;
import io.quarkus.websockets.next.CloseReason;
import io.vertx.core.Vertx;

public class SecuritySupport {

static final SecuritySupport NOOP = new SecuritySupport(null, null);
private static final Logger LOG = Logger.getLogger(SecuritySupport.class);
static final SecuritySupport NOOP = new SecuritySupport(null, null, null, null);

private final Instance<CurrentIdentityAssociation> currentIdentity;
private final SecurityIdentity identity;
private final Runnable onClose;

SecuritySupport(Instance<CurrentIdentityAssociation> currentIdentity, SecurityIdentity identity) {
SecuritySupport(Instance<CurrentIdentityAssociation> currentIdentity, SecurityIdentity identity, Vertx vertx,
WebSocketConnectionImpl connection) {
this.currentIdentity = currentIdentity;
this.identity = currentIdentity != null ? Objects.requireNonNull(identity) : identity;
if (this.currentIdentity != null) {
this.identity = Objects.requireNonNull(identity);
this.onClose = closeConnectionWhenIdentityExpired(vertx, connection, this.identity);
} else {
this.identity = null;
this.onClose = null;
}
}

/**
Expand All @@ -29,4 +43,25 @@ void start() {
}
}

void onClose() {
if (onClose != null) {
onClose.run();
}
}

private static Runnable closeConnectionWhenIdentityExpired(Vertx vertx, WebSocketConnectionImpl connection,
SecurityIdentity identity) {
if (identity.getAttribute("quarkus.identity.expire-time") instanceof Long expireAt) {
long timerId = vertx.setTimer(TimeUnit.SECONDS.toMillis(expireAt) - System.currentTimeMillis(),
ignored -> connection
.close(new CloseReason(1008, "Authentication expired"))
.subscribe()
.with(
v -> LOG.tracef("Closed connection due to expired authentication: %s", connection),
e -> LOG.errorf("Unable to close connection [%s] after authentication "
+ "expired due to unhandled failure: %s", connection, e)));
return () -> vertx.cancelTimer(timerId);
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ public Handler<RoutingContext> createEndpointHandler(String generatedEndpointCla

@Override
public void handle(RoutingContext ctx) {
SecuritySupport securitySupport = initializeSecuritySupport(container, ctx);

Future<ServerWebSocket> future = ctx.request().toWebSocket();
future.onSuccess(ws -> {
Vertx vertx = VertxCoreRecorder.getVertx().get();
Expand All @@ -101,6 +99,8 @@ public void handle(RoutingContext ctx) {
connectionManager.add(generatedEndpointClass, connection);
LOG.debugf("Connection created: %s", connection);

SecuritySupport securitySupport = initializeSecuritySupport(container, ctx, vertx, connection);

Endpoints.initialize(vertx, container, codecs, connection, ws, generatedEndpointClass,
config.autoPingInterval(), securitySupport, config.unhandledFailureStrategy(),
() -> connectionManager.remove(generatedEndpointClass, connection));
Expand All @@ -109,14 +109,15 @@ public void handle(RoutingContext ctx) {
};
}

SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx) {
SecuritySupport initializeSecuritySupport(ArcContainer container, RoutingContext ctx, Vertx vertx,
WebSocketConnectionImpl connection) {
Instance<CurrentIdentityAssociation> currentIdentityAssociation = container.select(CurrentIdentityAssociation.class);
if (currentIdentityAssociation.isResolvable()) {
// Security extension is present
// Obtain the current security identity from the handshake request
QuarkusHttpUser user = (QuarkusHttpUser) ctx.user();
if (user != null) {
return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity());
return new SecuritySupport(currentIdentityAssociation, user.getSecurityIdentity(), vertx, connection);
}
}
return SecuritySupport.NOOP;
Expand Down