|
33 | 33 | import io.grpc.ServerCall.Listener; |
34 | 34 | import io.grpc.internal.NoopServerCall; |
35 | 35 | import java.io.ByteArrayInputStream; |
| 36 | +import java.io.IOException; |
36 | 37 | import java.io.InputStream; |
37 | 38 | import java.util.ArrayList; |
38 | 39 | import java.util.Arrays; |
@@ -425,6 +426,91 @@ public void onMessage(ReqT message) { |
425 | 426 | order); |
426 | 427 | } |
427 | 428 |
|
| 429 | + /** |
| 430 | + * Tests the {@link ServerInterceptors#useMarshalledMessages(ServerServiceDefinition, Marshaller, Marshaller)}. |
| 431 | + * Makes sure that on incoming request the request marshaller's stream method is called and on response the |
| 432 | + * response marshaller's parse method is called |
| 433 | + */ |
| 434 | + @Test |
| 435 | + public void distinctMarshallerForRequestAndResponse() { |
| 436 | + final List<String> requestFlowOrder = new ArrayList<>(); |
| 437 | + |
| 438 | + final Marshaller<String> requestMarshaller = new Marshaller<String>() { |
| 439 | + @Override |
| 440 | + public InputStream stream(String value) { |
| 441 | + requestFlowOrder.add("RequestStream"); |
| 442 | + return new ByteArrayInputStream(value.getBytes()); |
| 443 | + } |
| 444 | + |
| 445 | + @Override |
| 446 | + public String parse(InputStream stream) { |
| 447 | + requestFlowOrder.add("RequestParse"); |
| 448 | + try { |
| 449 | + byte[] bytes = new byte[stream.available()]; |
| 450 | + stream.read(bytes); |
| 451 | + return new String(bytes); |
| 452 | + } catch (IOException e) { |
| 453 | + throw new RuntimeException(e); |
| 454 | + } |
| 455 | + } |
| 456 | + }; |
| 457 | + final Marshaller<String> responseMarshaller = new Marshaller<String>() { |
| 458 | + @Override |
| 459 | + public InputStream stream(String value) { |
| 460 | + requestFlowOrder.add("ResponseStream"); |
| 461 | + return new ByteArrayInputStream(value.getBytes()); |
| 462 | + } |
| 463 | + |
| 464 | + @Override |
| 465 | + public String parse(InputStream stream) { |
| 466 | + requestFlowOrder.add("ResponseParse"); |
| 467 | + try { |
| 468 | + byte[] bytes = new byte[stream.available()]; |
| 469 | + stream.read(bytes); |
| 470 | + return new String(bytes); |
| 471 | + } catch (IOException e) { |
| 472 | + throw new RuntimeException(e); |
| 473 | + } |
| 474 | + } |
| 475 | + }; |
| 476 | + final Marshaller<Holder> dummyMarshaller = new Marshaller<Holder>() { |
| 477 | + @Override |
| 478 | + public InputStream stream(Holder value) { |
| 479 | + return value.get(); |
| 480 | + } |
| 481 | + |
| 482 | + @Override |
| 483 | + public Holder parse(InputStream stream) { |
| 484 | + return new Holder(stream); |
| 485 | + } |
| 486 | + }; |
| 487 | + ServerCallHandler<Holder, Holder> handler = (call, headers) -> new Listener<Holder>() { |
| 488 | + @Override |
| 489 | + public void onMessage(Holder message) { |
| 490 | + requestFlowOrder.add("handler"); |
| 491 | + call.sendMessage(message); |
| 492 | + } |
| 493 | + }; |
| 494 | + |
| 495 | + MethodDescriptor<Holder, Holder> wrappedMethod = MethodDescriptor.<Holder, Holder>newBuilder() |
| 496 | + .setType(MethodType.UNKNOWN) |
| 497 | + .setFullMethodName("basic/wrapped") |
| 498 | + .setRequestMarshaller(dummyMarshaller) |
| 499 | + .setResponseMarshaller(dummyMarshaller) |
| 500 | + .build(); |
| 501 | + ServerServiceDefinition serviceDef = ServerServiceDefinition.builder( |
| 502 | + new ServiceDescriptor("basic", wrappedMethod)) |
| 503 | + .addMethod(wrappedMethod, handler).build(); |
| 504 | + ServerServiceDefinition intercepted = ServerInterceptors.useMarshalledMessages(serviceDef, requestMarshaller, |
| 505 | + responseMarshaller); |
| 506 | + ServerMethodDefinition<String, String> serverMethod = |
| 507 | + (ServerMethodDefinition<String, String>) intercepted.getMethod("basic/wrapped"); |
| 508 | + ServerCall<String, String> serverCall = new NoopServerCall<>(); |
| 509 | + serverMethod.getServerCallHandler().startCall(serverCall, headers).onMessage("TestMessage"); |
| 510 | + |
| 511 | + assertEquals(Arrays.asList("RequestStream", "handler", "ResponseParse"), requestFlowOrder); |
| 512 | + } |
| 513 | + |
428 | 514 | @SuppressWarnings("unchecked") |
429 | 515 | private static ServerMethodDefinition<String, Integer> getSoleMethod( |
430 | 516 | ServerServiceDefinition serviceDef) { |
|
0 commit comments