Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package io.quarkus.websockets.next.test.connection;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.net.URI;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OnTextMessage;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketClient;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketConnector;
import io.quarkus.websockets.next.test.utils.WSClient;

public class ConnectionIdleTimeoutTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(ServerEndpoint.class, ClientEndpoint.class, WSClient.class);
}).overrideConfigKey("quarkus.websockets-next.client.connection-idle-timeout", "500ms");;

@TestHTTPResource("/")
URI uri;

@Inject
WebSocketConnector<ClientEndpoint> connector;

@Test
public void testTimeout() throws InterruptedException {
WebSocketClientConnection conn = connector.baseUri(uri.toString()).connectAndAwait();
ExecutorService executor = Executors.newSingleThreadExecutor();
try {
TimeUnit.MILLISECONDS.sleep(500);
executor.execute(() -> {
try {
conn.sendTextAndAwait("ok");
} catch (Throwable ignored) {
}
});
} finally {
executor.shutdownNow();
}
assertTrue(ServerEndpoint.CLOSED.await(5, TimeUnit.SECONDS));
assertTrue(ClientEndpoint.CLOSED.await(5, TimeUnit.SECONDS));
assertFalse(ServerEndpoint.MESSAGE.get());
}

@WebSocket(path = "/end")
public static class ServerEndpoint {

static final CountDownLatch CLOSED = new CountDownLatch(1);
static final AtomicBoolean MESSAGE = new AtomicBoolean();

@OnTextMessage
void onText(String message) {
MESSAGE.set(true);
}

@OnClose
void close() {
CLOSED.countDown();
}

}

@WebSocketClient(path = "/end")
public static class ClientEndpoint {

static final CountDownLatch CLOSED = new CountDownLatch(1);

@OnOpen
void open() {
}

@OnClose
void close(WebSocketClientConnection conn) {
CLOSED.countDown();
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ public Uni<WebSocketClientConnection> connect() {
throw new WebSocketClientException(e);
}

Uni<WebSocket> websocket = Uni.createFrom().<WebSocket> emitter(e -> {
Uni<WebSocketOpen> websocketOpen = Uni.createFrom().<WebSocketOpen> emitter(e -> {
// Create a new event loop context for each client, otherwise the current context is used
// We want to avoid a situation where if multiple clients/connections are created in a row,
// the same event loop is used and so writing/receiving messages is de-facto serialized
Expand All @@ -171,7 +171,7 @@ public void handle(Void event) {
@Override
public void handle(AsyncResult<WebSocket> r) {
if (r.succeeded()) {
e.complete(r.result());
e.complete(new WebSocketOpen(newCleanupConsumer(c, context), r.result()));
} else {
e.fail(r.cause());
}
Expand All @@ -183,14 +183,20 @@ public void handle(AsyncResult<WebSocket> r) {
}
});
});
return websocket.map(ws -> {
return websocketOpen.map(wsOpen -> {
WebSocket ws = wsOpen.websocket();
String clientId = BasicWebSocketConnector.class.getName();
TrafficLogger trafficLogger = TrafficLogger.forClient(config);
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId, ws,
WebSocketClientConnectionImpl connection = new WebSocketClientConnectionImpl(clientId,
ws,
codecs,
pathParams,
serverEndpointUri,
headers, trafficLogger, userData, null);
headers,
trafficLogger,
userData,
null,
wsOpen.cleanup());
if (trafficLogger != null) {
trafficLogger.connectionOpened(connection);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

import java.util.Iterator;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.stream.Stream;

import jakarta.annotation.PreDestroy;
import jakarta.enterprise.event.Event;
import jakarta.inject.Singleton;

import org.jboss.logging.Logger;

import io.quarkus.arc.Arc;
import io.quarkus.arc.ArcContainer;
import io.quarkus.runtime.Shutdown;
import io.quarkus.websockets.next.Closed;
import io.quarkus.websockets.next.Open;
import io.quarkus.websockets.next.OpenClientConnections;
Expand All @@ -26,7 +27,7 @@ public class ClientConnectionManager implements OpenClientConnections {

private static final Logger LOG = Logger.getLogger(ClientConnectionManager.class);

private final ConcurrentMap<String, Set<WebSocketClientConnection>> endpointToConnections = new ConcurrentHashMap<>();
private final ConcurrentMap<String, Set<WebSocketClientConnectionImpl>> endpointToConnections = new ConcurrentHashMap<>();

private final List<ClientConnectionListener> listeners = new CopyOnWriteArrayList<>();

Expand All @@ -50,10 +51,11 @@ public Iterator<WebSocketClientConnection> iterator() {

@Override
public Stream<WebSocketClientConnection> stream() {
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen);
return endpointToConnections.values().stream().flatMap(Set::stream).filter(WebSocketClientConnection::isOpen)
.map(WebSocketClientConnection.class::cast);
}

void add(String endpoint, WebSocketClientConnection connection) {
void add(String endpoint, WebSocketClientConnectionImpl connection) {
LOG.debugf("Add client connection: %s", connection);
if (endpointToConnections.computeIfAbsent(endpoint, e -> ConcurrentHashMap.newKeySet()).add(connection)) {
if (openEvent != null) {
Expand All @@ -72,9 +74,9 @@ void add(String endpoint, WebSocketClientConnection connection) {
}
}

void remove(String endpoint, WebSocketClientConnection connection) {
void remove(String endpoint, WebSocketClientConnectionImpl connection) {
LOG.debugf("Remove client connection: %s", connection);
Set<WebSocketClientConnection> connections = endpointToConnections.get(endpoint);
Set<WebSocketClientConnectionImpl> connections = endpointToConnections.get(endpoint);
if (connections != null) {
if (connections.remove(connection)) {
if (closedEvent != null) {
Expand All @@ -99,8 +101,8 @@ void remove(String endpoint, WebSocketClientConnection connection) {
* @param endpoint
* @return the connections for the given client endpoint, never {@code null}
*/
public Set<WebSocketClientConnection> getConnections(String endpoint) {
Set<WebSocketClientConnection> ret = endpointToConnections.get(endpoint);
public Set<WebSocketClientConnectionImpl> getConnections(String endpoint) {
Set<WebSocketClientConnectionImpl> ret = endpointToConnections.get(endpoint);
if (ret == null) {
return Set.of();
}
Expand All @@ -111,9 +113,19 @@ public void addListener(ClientConnectionListener listener) {
this.listeners.add(listener);
}

@PreDestroy
void destroy() {
endpointToConnections.clear();
@Shutdown
void cleanup() {
if (!endpointToConnections.isEmpty()) {
int sum = 0;
for (Entry<String, Set<WebSocketClientConnectionImpl>> e : endpointToConnections.entrySet()) {
for (WebSocketClientConnectionImpl c : e.getValue()) {
c.cleanup();
sum++;
}
}
LOG.debugf("Cleanup performed for %s connections", sum);
endpointToConnections.clear();
}
}

public interface ClientConnectionListener {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.function.Consumer;

import io.quarkus.websockets.next.HandshakeRequest;
import io.quarkus.websockets.next.WebSocketClientConnection;
Expand All @@ -19,13 +20,17 @@ class WebSocketClientConnectionImpl extends WebSocketConnectionBase implements W

private final WebSocket webSocket;

private final Consumer<WebSocketClientConnection> cleanup;

WebSocketClientConnectionImpl(String clientId, WebSocket webSocket, Codecs codecs,
Map<String, String> pathParams, URI serverEndpointUri, Map<String, List<String>> headers,
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor) {
TrafficLogger trafficLogger, Map<String, Object> userData, SendingInterceptor sendingInterceptor,
Consumer<WebSocketClientConnection> cleanup) {
super(Map.copyOf(pathParams), codecs, new ClientHandshakeRequestImpl(serverEndpointUri, headers), trafficLogger,
new UserDataImpl(userData), sendingInterceptor);
this.clientId = clientId;
this.webSocket = Objects.requireNonNull(webSocket);
this.cleanup = cleanup;
}

@Override
Expand All @@ -48,6 +53,12 @@ public int hashCode() {
return Objects.hash(identifier);
}

protected void cleanup() {
if (cleanup != null) {
cleanup.accept(this);
}
}

@Override
public boolean equals(Object obj) {
if (this == obj)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.net.URI;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
Expand All @@ -11,23 +12,33 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.jboss.logging.Logger;

import io.quarkus.tls.TlsConfiguration;
import io.quarkus.tls.TlsConfigurationRegistry;
import io.quarkus.tls.runtime.config.TlsConfigUtils;
import io.quarkus.websockets.next.UserData.TypedKey;
import io.quarkus.websockets.next.WebSocketClientConnection;
import io.quarkus.websockets.next.WebSocketClientException;
import io.quarkus.websockets.next.runtime.config.WebSocketsClientRuntimeConfig;
import io.vertx.core.Vertx;
import io.vertx.core.http.WebSocket;
import io.vertx.core.http.WebSocketClient;
import io.vertx.core.http.WebSocketClientOptions;
import io.vertx.core.http.WebSocketConnectOptions;
import io.vertx.core.impl.ContextImpl;

abstract class WebSocketConnectorBase<THIS extends WebSocketConnectorBase<THIS>> {

protected static final Pattern PATH_PARAM_PATTERN = Pattern.compile("\\{[a-zA-Z0-9_]+\\}");

private static final Logger LOG = Logger.getLogger(WebSocketConnectorBase.class);

// mutable state

protected URI baseUri;
Expand Down Expand Up @@ -172,7 +183,16 @@ protected WebSocketClientOptions populateClientOptions() {
if (config.maxFrameSize().isPresent()) {
clientOptions.setMaxFrameSize(config.maxFrameSize().getAsInt());
}

if (config.connectionIdleTimeout().isPresent()) {
Duration timeout = config.connectionIdleTimeout().get();
if (timeout.toMillis() > Integer.MAX_VALUE) {
clientOptions.setIdleTimeoutUnit(TimeUnit.SECONDS);
clientOptions.setIdleTimeout((int) timeout.toSeconds());
} else {
clientOptions.setIdleTimeoutUnit(TimeUnit.MILLISECONDS);
clientOptions.setIdleTimeout((int) timeout.toMillis());
}
}
Optional<TlsConfiguration> maybeTlsConfiguration = TlsConfiguration.from(tlsConfigurationRegistry,
Optional.ofNullable(tlsConfigurationName));
if (maybeTlsConfiguration.isEmpty()) {
Expand Down Expand Up @@ -201,4 +221,28 @@ protected WebSocketConnectOptions newConnectOptions(URI serverEndpointUri) {
protected boolean isSecure(URI uri) {
return "https".equals(uri.getScheme()) || "wss".equals(uri.getScheme());
}

record WebSocketOpen(Consumer<WebSocketClientConnection> cleanup, WebSocket websocket) {
}

Consumer<WebSocketClientConnection> newCleanupConsumer(WebSocketClient client, ContextImpl context) {
return new Consumer<WebSocketClientConnection>() {
@Override
public void accept(WebSocketClientConnection conn) {
try {
client.close();
LOG.debugf("Client closed for connection %s", conn.id());
} catch (Throwable e) {
LOG.errorf(e, "Unable to close the client for connection %s", conn.id());
}
try {
context.close();
LOG.debugf("Context closed for connection %s", conn.id());
} catch (Throwable e) {
LOG.errorf(e, "Unable to close the context for connection %s", conn.id());
}
}
};
}

}
Loading
Loading