Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package io.quarkus.resteasy.reactive.server.test.headers;

import static io.restassured.RestAssured.when;
import static org.assertj.core.api.Assertions.*;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.Produces;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;

import org.jboss.resteasy.reactive.RestResponse;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.restassured.http.Headers;

public class IgnoredResponseHeadersTest {

@RegisterExtension
static QuarkusUnitTest TEST = new QuarkusUnitTest()
.withApplicationRoot((jar) -> jar.addClasses(Resource.class));

@Test
public void testResponse() {
doTest("response");
}

@Test
public void testRestResponse() {
doTest("rest-response");
}

private static void doTest(String path) {
Headers responseHeaders = when()
.get("/resource/" + path)
.then()
.statusCode(200)
.extract().headers();

assertThat(responseHeaders.getList("Transfer-Encoding"))
.extracting("value")
.singleElement().isEqualTo("chunked");
}

@Path("resource")
public static class Resource {

@Path("response")
@Produces(MediaType.TEXT_PLAIN)
@GET
public Response response() {
return Response.ok(largeString())
.header("Transfer-Encoding", "chunked")
.header("Content-Type", "text/plain")
.build();
}

@Path("rest-response")
@Produces(MediaType.TEXT_PLAIN)
@GET
public RestResponse<InputStream> restResponse() {
return RestResponse.ResponseBuilder.ok(largeString())
.header("Transfer-Encoding", "chunked")
.header("Content-Type", "text/plain")
.build();
}

private static InputStream largeString() {
String content = IntStream.range(1, 100_000).mapToObj(i -> "Hello no." + i).collect(Collectors.joining(","));
return new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import static io.restassured.RestAssured.when;
import static org.assertj.core.api.Assertions.assertThat;
import static org.hamcrest.CoreMatchers.is;

import java.io.IOException;

Expand All @@ -20,7 +19,6 @@

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.vertx.web.RouteFilter;
import io.restassured.http.Headers;
import io.vertx.ext.web.RoutingContext;

public class VertxHeadersTest {
Expand All @@ -38,16 +36,6 @@ void testVaryHeaderValues() {
assertThat(headers.getValues(HttpHeaders.VARY)).containsExactlyInAnyOrder("Origin", "Prefer");
}

@Test
void testTransferEncodingHeaderValues() {
Headers headers = when().get("/test/response")
.then()
.statusCode(200)
.header("Transfer-Encoding", is("chunked")).extract().headers();

assertThat(headers.asList()).noneMatch(h -> h.getName().equals("transfer-encoding"));
}

public static class VertxFilter {
@RouteFilter
void addVary(final RoutingContext rc) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ public void accept(ResteasyReactiveRequestContext context) {
private static final String LENGTH = "Length";
private static final String LENGTH_LOWER = "length";
private static final String CONTENT_TYPE = CONTENT + "-" + TYPE; // use this instead of the Vert.x constant because the TCK expects upper case
private static final String TRANSFER_ENCODING = "Transfer-Encoding";
public static final String TRANSFER_ENCODING = "Transfer-Encoding";

public final static List<Serialisers.BuiltinReader> BUILTIN_READERS = List.of(
new Serialisers.BuiltinReader(String.class, ServerStringMessageBodyHandler.class,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import java.io.ByteArrayInputStream;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

import jakarta.ws.rs.core.GenericEntity;
import jakarta.ws.rs.core.HttpHeaders;
Expand All @@ -15,6 +17,7 @@
import org.jboss.resteasy.reactive.server.core.EncodedMediaType;
import org.jboss.resteasy.reactive.server.core.LazyResponse;
import org.jboss.resteasy.reactive.server.core.ResteasyReactiveRequestContext;
import org.jboss.resteasy.reactive.server.core.ServerSerialisers;
import org.jboss.resteasy.reactive.server.jaxrs.ResponseBuilderImpl;
import org.jboss.resteasy.reactive.server.spi.ServerRestHandler;

Expand All @@ -26,6 +29,10 @@ public class ResponseHandler implements ServerRestHandler {

public static final ResponseHandler NO_CUSTOMIZER_INSTANCE = new ResponseHandler();

// TODO: we need to think about what other headers coming from the existing Response need to be ignored
private static final Set<String> IGNORED_HEADERS = Collections.singleton(ServerSerialisers.TRANSFER_ENCODING.toLowerCase(
Locale.ROOT));

private final List<ResponseBuilderCustomizer> responseBuilderCustomizers;

public ResponseHandler(List<ResponseBuilderCustomizer> responseBuilderCustomizers) {
Expand All @@ -39,14 +46,12 @@ private ResponseHandler() {
@Override
public void handle(ResteasyReactiveRequestContext requestContext) throws Exception {
Object result = requestContext.getResult();
if (result instanceof Response) {
if (result instanceof Response existing) {
boolean mediaTypeAlreadyExists = false;
//we already have a response
//set it explicitly
ResponseBuilderImpl responseBuilder;
Response existing = (Response) result;
if (existing.getEntity() instanceof GenericEntity) {
GenericEntity<?> genericEntity = (GenericEntity<?>) existing.getEntity();
if (existing.getEntity() instanceof GenericEntity<?> genericEntity) {
requestContext.setGenericReturnType(genericEntity.getType());
responseBuilder = fromResponse(existing);
responseBuilder.entity(genericEntity.getEntity());
Expand All @@ -56,17 +61,15 @@ public void handle(ResteasyReactiveRequestContext requestContext) throws Excepti
requestContext.setGenericReturnType(existing.getEntity().getClass());
//TODO: super inefficient
responseBuilder = fromResponse(existing);
if ((result instanceof ResponseImpl)) {
if ((result instanceof ResponseImpl responseImpl)) {
// needed in order to preserve entity annotations
ResponseImpl responseImpl = (ResponseImpl) result;
if (responseImpl.getEntityAnnotations() != null) {
requestContext.setAdditionalAnnotations(responseImpl.getEntityAnnotations());
}

// this is a weird case where the response comes from the the rest-client
// this is a weird case where the response comes from the rest-client
if (responseBuilder.getEntity() == null) {
if (responseImpl.getEntityStream() instanceof ByteArrayInputStream) {
ByteArrayInputStream byteArrayInputStream = (ByteArrayInputStream) responseImpl.getEntityStream();
if (responseImpl.getEntityStream() instanceof ByteArrayInputStream byteArrayInputStream) {
responseBuilder.entity(byteArrayInputStream.readAllBytes());
}
}
Expand All @@ -88,31 +91,27 @@ public void handle(ResteasyReactiveRequestContext requestContext) throws Excepti
} else {
requestContext.setResponse(new LazyResponse.Existing(responseBuilder.build()));
}
} else if (result instanceof RestResponse) {
} else if (result instanceof RestResponse<?> existing) {
boolean mediaTypeAlreadyExists = false;
//we already have a response
//set it explicitly
ResponseBuilderImpl responseBuilder;
RestResponse<?> existing = (RestResponse<?>) result;
if (existing.getEntity() instanceof GenericEntity) {
GenericEntity<?> genericEntity = (GenericEntity<?>) existing.getEntity();
if (existing.getEntity() instanceof GenericEntity<?> genericEntity) {
requestContext.setGenericReturnType(genericEntity.getType());
responseBuilder = fromResponse(existing);
responseBuilder.entity(genericEntity.getEntity());
} else {
//TODO: super inefficient
responseBuilder = fromResponse(existing);
if ((result instanceof RestResponseImpl)) {
if ((result instanceof RestResponseImpl<?> responseImpl)) {
// needed in order to preserve entity annotations
RestResponseImpl<?> responseImpl = (RestResponseImpl<?>) result;
if (responseImpl.getEntityAnnotations() != null) {
requestContext.setAdditionalAnnotations(responseImpl.getEntityAnnotations());
}

// this is a weird case where the response comes from the the rest-client
// this is a weird case where the response comes from the rest-client
if (responseBuilder.getEntity() == null) {
if (responseImpl.getEntityStream() instanceof ByteArrayInputStream) {
ByteArrayInputStream byteArrayInputStream = (ByteArrayInputStream) responseImpl.getEntityStream();
if (responseImpl.getEntityStream() instanceof ByteArrayInputStream byteArrayInputStream) {
responseBuilder.entity(byteArrayInputStream.readAllBytes());
}
}
Expand Down Expand Up @@ -198,6 +197,9 @@ private ResponseBuilderImpl fromResponse(Response response) {
var headers = response.getHeaders();
if (headers != null) {
for (String headerName : headers.keySet()) {
if (IGNORED_HEADERS.contains(headerName.toLowerCase(Locale.ROOT))) {
continue;
}
List<Object> headerValues = headers.get(headerName);
for (Object headerValue : headerValues) {
b.header(headerName, headerValue);
Expand All @@ -214,6 +216,9 @@ private ResponseBuilderImpl fromResponse(RestResponse<?> response) {
b.entity(response.getEntity());
}
for (String headerName : response.getHeaders().keySet()) {
if (IGNORED_HEADERS.contains(headerName.toLowerCase(Locale.ROOT))) {
continue;
}
List<Object> headerValues = response.getHeaders().get(headerName);
for (Object headerValue : headerValues) {
b.header(headerName, headerValue);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.nullValue;

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

import jakarta.ws.rs.GET;
import jakarta.ws.rs.Path;
import jakarta.ws.rs.core.Response;
Expand Down Expand Up @@ -34,7 +40,12 @@ public static class TestResource {
@GET
@Path("hello")
public Response hello() {
return Response.ok("hello").header("Transfer-Encoding", "chunked").build();
return Response.ok(largeString()).header("Transfer-Encoding", "chunked").build();
}

private static InputStream largeString() {
String content = IntStream.range(1, 100_000).mapToObj(i -> "Hello no." + i).collect(Collectors.joining(","));
return new ByteArrayInputStream(content.getBytes(StandardCharsets.UTF_8));
}
}
}
Loading