Skip to content

Commit 23f5a6f

Browse files
committed
Add support for http forward proxy with CONNECT
This is a squash and modification of master commits that also includes: netty,okhttp: Fix CONNECT and its error handling This commit has been modified to reduce its size to substantially reduce risk of it breaking Netty error handling. But that also means proxy error handling just provides a useless "there was an error" sort of message. There is no Java API to enable the proxy support. Instead, you must set the GRPC_PROXY_EXP environment variable which should be set to a host:port string. The environment variable is temporary; it will not exist in future releases. It exists to provide support without needing explicit code to enable the future, while at the same time not risking enabling it for existing users.
1 parent 5bfac21 commit 23f5a6f

File tree

8 files changed

+489
-12
lines changed

8 files changed

+489
-12
lines changed

build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ subprojects {
163163

164164
netty: 'io.netty:netty-codec-http2:[4.1.7.Final]',
165165
netty_epoll: 'io.netty:netty-transport-native-epoll:4.1.7.Final' + epoll_suffix,
166+
netty_proxy_handler: 'io.netty:netty-handler-proxy:4.1.7.Final',
166167
netty_tcnative: 'io.netty:netty-tcnative-boringssl-static:1.1.33.Fork25',
167168

168169
// Test dependencies.

netty/build.gradle

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
description = "gRPC: Netty"
22
dependencies {
33
compile project(':grpc-core'),
4-
libraries.netty
4+
libraries.netty,
5+
libraries.netty_proxy_handler
56

67
// Tests depend on base class defined by core module.
78
testCompile project(':grpc-core').sourceSets.test.output,

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,25 @@ static ProtocolNegotiator createProtocolNegotiator(
291291
String authority,
292292
NegotiationType negotiationType,
293293
SslContext sslContext) {
294+
ProtocolNegotiator negotiator =
295+
createProtocolNegotiatorByType(authority, negotiationType, sslContext);
296+
String proxy = System.getenv("GRPC_PROXY_EXP");
297+
if (proxy != null) {
298+
String[] parts = proxy.split(":", 2);
299+
int port = 80;
300+
if (parts.length > 1) {
301+
port = Integer.parseInt(parts[1]);
302+
}
303+
InetSocketAddress proxyAddress = new InetSocketAddress(parts[0], port);
304+
negotiator = ProtocolNegotiators.httpProxy(proxyAddress, null, null, negotiator);
305+
}
306+
return negotiator;
307+
}
308+
309+
private static ProtocolNegotiator createProtocolNegotiatorByType(
310+
String authority,
311+
NegotiationType negotiationType,
312+
SslContext sslContext) {
294313
switch (negotiationType) {
295314
case PLAINTEXT:
296315
return ProtocolNegotiators.plaintext();

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import io.netty.channel.ChannelHandler;
4646
import io.netty.channel.ChannelHandlerAdapter;
4747
import io.netty.channel.ChannelHandlerContext;
48+
import io.netty.channel.ChannelInboundHandler;
4849
import io.netty.channel.ChannelInboundHandlerAdapter;
4950
import io.netty.channel.ChannelPipeline;
5051
import io.netty.channel.ChannelPromise;
@@ -54,13 +55,17 @@
5455
import io.netty.handler.codec.http.HttpMethod;
5556
import io.netty.handler.codec.http.HttpVersion;
5657
import io.netty.handler.codec.http2.Http2ClientUpgradeCodec;
58+
import io.netty.handler.proxy.HttpProxyHandler;
59+
import io.netty.handler.proxy.ProxyConnectionEvent;
60+
import io.netty.handler.proxy.ProxyHandler;
5761
import io.netty.handler.ssl.OpenSsl;
5862
import io.netty.handler.ssl.OpenSslEngine;
5963
import io.netty.handler.ssl.SslContext;
6064
import io.netty.handler.ssl.SslHandler;
6165
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
6266
import io.netty.util.AsciiString;
6367
import io.netty.util.ReferenceCountUtil;
68+
import java.net.SocketAddress;
6469
import java.net.URI;
6570
import java.util.ArrayDeque;
6671
import java.util.Arrays;
@@ -189,6 +194,73 @@ public AsciiString scheme() {
189194
}
190195
}
191196

197+
/**
198+
* Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation.
199+
*/
200+
public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress,
201+
final @Nullable String proxyUsername, final @Nullable String proxyPassword,
202+
final ProtocolNegotiator negotiator) {
203+
Preconditions.checkNotNull(proxyAddress, "proxyAddress");
204+
Preconditions.checkNotNull(negotiator, "negotiator");
205+
class ProxyNegotiator implements ProtocolNegotiator {
206+
@Override
207+
public Handler newHandler(GrpcHttp2ConnectionHandler http2Handler) {
208+
HttpProxyHandler proxyHandler;
209+
if (proxyUsername == null || proxyPassword == null) {
210+
proxyHandler = new HttpProxyHandler(proxyAddress);
211+
} else {
212+
proxyHandler = new HttpProxyHandler(proxyAddress, proxyUsername, proxyPassword);
213+
}
214+
return new BufferUntilProxyTunnelledHandler(
215+
proxyHandler, negotiator.newHandler(http2Handler));
216+
}
217+
}
218+
219+
return new ProxyNegotiator();
220+
}
221+
222+
/**
223+
* Buffers all writes until the HTTP CONNECT tunnel is established.
224+
*/
225+
static final class BufferUntilProxyTunnelledHandler extends AbstractBufferingHandler
226+
implements ProtocolNegotiator.Handler {
227+
private final ProtocolNegotiator.Handler originalHandler;
228+
229+
public BufferUntilProxyTunnelledHandler(
230+
ProxyHandler proxyHandler, ProtocolNegotiator.Handler handler) {
231+
super(proxyHandler, handler);
232+
this.originalHandler = handler;
233+
}
234+
235+
236+
@Override
237+
public AsciiString scheme() {
238+
return originalHandler.scheme();
239+
}
240+
241+
@Override
242+
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
243+
if (evt instanceof ProxyConnectionEvent) {
244+
writeBufferedAndRemove(ctx);
245+
}
246+
super.userEventTriggered(ctx, evt);
247+
}
248+
249+
@Override
250+
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
251+
fail(ctx, unavailableException("Connection broken while trying to CONNECT through proxy"));
252+
super.channelInactive(ctx);
253+
}
254+
255+
@Override
256+
public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception {
257+
if (ctx.channel().isActive()) { // This may be a notification that the socket was closed
258+
fail(ctx, unavailableException("Channel closed while trying to CONNECT through proxy"));
259+
}
260+
super.close(ctx, future);
261+
}
262+
}
263+
192264
/**
193265
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
194266
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
@@ -366,10 +438,22 @@ public void channelRegistered(ChannelHandlerContext ctx) throws Exception {
366438
* lifetime and we only want to configure it once.
367439
*/
368440
if (handlers != null) {
369-
ctx.pipeline().addFirst(handlers);
441+
for (ChannelHandler handler : handlers) {
442+
ctx.pipeline().addBefore(ctx.name(), null, handler);
443+
}
444+
ChannelHandler handler0 = handlers[0];
445+
ChannelHandlerContext handler0Ctx = ctx.pipeline().context(handlers[0]);
370446
handlers = null;
447+
if (handler0Ctx != null) { // The handler may have removed itself immediately
448+
if (handler0 instanceof ChannelInboundHandler) {
449+
((ChannelInboundHandler) handler0).channelRegistered(handler0Ctx);
450+
} else {
451+
handler0Ctx.fireChannelRegistered();
452+
}
453+
}
454+
} else {
455+
super.channelRegistered(ctx);
371456
}
372-
super.channelRegistered(ctx);
373457
}
374458

375459
@Override
@@ -424,7 +508,10 @@ public void flush(ChannelHandlerContext ctx) {
424508

425509
@Override
426510
public void close(ChannelHandlerContext ctx, ChannelPromise future) throws Exception {
427-
fail(ctx, unavailableException("Channel closed while performing protocol negotiation"));
511+
if (ctx.channel().isActive()) { // This may be a notification that the socket was closed
512+
fail(ctx, unavailableException("Channel closed while performing protocol negotiation"));
513+
}
514+
super.close(ctx, future);
428515
}
429516

430517
protected final void fail(ChannelHandlerContext ctx, Throwable cause) {

netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java

Lines changed: 113 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,25 +31,41 @@
3131

3232
package io.grpc.netty;
3333

34+
import static com.google.common.base.Charsets.UTF_8;
3435
import static org.junit.Assert.assertEquals;
3536
import static org.junit.Assert.assertFalse;
3637
import static org.junit.Assert.assertNotNull;
3738
import static org.junit.Assert.assertNull;
3839
import static org.junit.Assert.assertTrue;
40+
import static org.mockito.Matchers.any;
3941
import static org.mockito.Mockito.mock;
42+
import static org.mockito.Mockito.times;
4043

4144
import io.grpc.netty.ProtocolNegotiators.ServerTlsHandler;
4245
import io.grpc.netty.ProtocolNegotiators.TlsNegotiator;
4346
import io.grpc.testing.TestUtils;
47+
import io.netty.bootstrap.Bootstrap;
48+
import io.netty.bootstrap.ServerBootstrap;
49+
import io.netty.buffer.ByteBuf;
50+
import io.netty.buffer.ByteBufUtil;
51+
import io.netty.channel.Channel;
52+
import io.netty.channel.ChannelFuture;
4453
import io.netty.channel.ChannelHandler;
4554
import io.netty.channel.ChannelHandlerContext;
55+
import io.netty.channel.ChannelInboundHandler;
4656
import io.netty.channel.ChannelPipeline;
57+
import io.netty.channel.DefaultEventLoopGroup;
4758
import io.netty.channel.embedded.EmbeddedChannel;
59+
import io.netty.channel.local.LocalAddress;
60+
import io.netty.channel.local.LocalChannel;
61+
import io.netty.channel.local.LocalServerChannel;
4862
import io.netty.handler.ssl.SslContext;
4963
import io.netty.handler.ssl.SslHandler;
5064
import io.netty.handler.ssl.SslHandshakeCompletionEvent;
5165
import io.netty.handler.ssl.SupportedCipherSuiteFilter;
5266
import java.io.File;
67+
import java.net.InetSocketAddress;
68+
import java.net.SocketAddress;
5369
import java.util.logging.Filter;
5470
import java.util.logging.Level;
5571
import java.util.logging.LogRecord;
@@ -63,10 +79,17 @@
6379
import org.junit.rules.ExpectedException;
6480
import org.junit.runner.RunWith;
6581
import org.junit.runners.JUnit4;
82+
import org.mockito.ArgumentCaptor;
83+
import org.mockito.Mockito;
6684

6785
@RunWith(JUnit4.class)
6886
public class ProtocolNegotiatorsTest {
69-
@Rule public final ExpectedException thrown = ExpectedException.none();
87+
private static final Runnable NOOP_RUNNABLE = new Runnable() {
88+
@Override public void run() {}
89+
};
90+
91+
@Rule
92+
public final ExpectedException thrown = ExpectedException.none();
7093

7194
private GrpcHttp2ConnectionHandler grpcHandler = mock(GrpcHttp2ConnectionHandler.class);
7295

@@ -81,7 +104,7 @@ public void setUp() throws Exception {
81104
File serverCert = TestUtils.loadCert("server1.pem");
82105
File key = TestUtils.loadCert("server1.key");
83106
sslContext = GrpcSslContexts.forServer(serverCert, key)
84-
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
107+
.ciphers(TestUtils.preferredTestCiphers(), SupportedCipherSuiteFilter.INSTANCE).build();
85108
engine = SSLContext.getDefault().createSSLEngine();
86109
}
87110

@@ -272,4 +295,92 @@ public void tls_invalidHost() throws SSLException {
272295
assertEquals("bad_host:1234", negotiator.getHost());
273296
assertEquals(-1, negotiator.getPort());
274297
}
298+
299+
@Test
300+
public void httpProxy_nullAddressNpe() throws Exception {
301+
thrown.expect(NullPointerException.class);
302+
ProtocolNegotiators.httpProxy(null, "user", "pass", ProtocolNegotiators.plaintext());
303+
}
304+
305+
@Test
306+
public void httpProxy_nullNegotiatorNpe() throws Exception {
307+
thrown.expect(NullPointerException.class);
308+
ProtocolNegotiators.httpProxy(
309+
InetSocketAddress.createUnresolved("localhost", 80), "user", "pass", null);
310+
}
311+
312+
@Test
313+
public void httpProxy_nullUserPassNoException() throws Exception {
314+
assertNotNull(ProtocolNegotiators.httpProxy(
315+
InetSocketAddress.createUnresolved("localhost", 80), null, null,
316+
ProtocolNegotiators.plaintext()));
317+
}
318+
319+
@Test(timeout = 5000)
320+
public void httpProxy_completes() throws Exception {
321+
DefaultEventLoopGroup elg = new DefaultEventLoopGroup(1);
322+
// ProxyHandler is incompatible with EmbeddedChannel because when channelRegistered() is called
323+
// the channel is already active.
324+
LocalAddress proxy = new LocalAddress("httpProxy_completes");
325+
SocketAddress host = InetSocketAddress.createUnresolved("specialHost", 314);
326+
327+
ChannelInboundHandler mockHandler = mock(ChannelInboundHandler.class);
328+
Channel serverChannel = new ServerBootstrap().group(elg).channel(LocalServerChannel.class)
329+
.childHandler(mockHandler)
330+
.bind(proxy).sync().channel();
331+
332+
ProtocolNegotiator nego =
333+
ProtocolNegotiators.httpProxy(proxy, null, null, ProtocolNegotiators.plaintext());
334+
ChannelHandler handler = nego.newHandler(grpcHandler);
335+
Channel channel = new Bootstrap().group(elg).channel(LocalChannel.class).handler(handler)
336+
.register().sync().channel();
337+
pipeline = channel.pipeline();
338+
// Wait for initialization to complete
339+
channel.eventLoop().submit(NOOP_RUNNABLE).sync();
340+
// The grpcHandler must be in the pipeline, but we don't actually want it during our test
341+
// because it will consume all events since it is a mock. We only use it because it is required
342+
// to construct the Handler.
343+
pipeline.remove(grpcHandler);
344+
channel.connect(host).sync();
345+
serverChannel.close();
346+
ArgumentCaptor<ChannelHandlerContext> contextCaptor =
347+
ArgumentCaptor.forClass(ChannelHandlerContext.class);
348+
Mockito.verify(mockHandler).channelActive(contextCaptor.capture());
349+
ChannelHandlerContext serverContext = contextCaptor.getValue();
350+
351+
final String golden = "isThisThingOn?";
352+
ChannelFuture negotiationFuture = channel.writeAndFlush(bb(golden, channel));
353+
354+
// Wait for sending initial request to complete
355+
channel.eventLoop().submit(NOOP_RUNNABLE).sync();
356+
ArgumentCaptor<Object> objectCaptor = ArgumentCaptor.forClass(Object.class);
357+
Mockito.verify(mockHandler)
358+
.channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
359+
ByteBuf b = (ByteBuf) objectCaptor.getValue();
360+
String request = b.toString(UTF_8);
361+
b.release();
362+
assertTrue("No trailing newline: " + request, request.endsWith("\r\n\r\n"));
363+
assertTrue("No CONNECT: " + request, request.startsWith("CONNECT specialHost:314 "));
364+
assertTrue("No host header: " + request, request.contains("host: specialHost:314"));
365+
366+
assertFalse(negotiationFuture.isDone());
367+
serverContext.writeAndFlush(bb("HTTP/1.1 200 OK\r\n\r\n", serverContext.channel())).sync();
368+
negotiationFuture.sync();
369+
370+
channel.eventLoop().submit(NOOP_RUNNABLE).sync();
371+
objectCaptor.getAllValues().clear();
372+
Mockito.verify(mockHandler, times(2))
373+
.channelRead(any(ChannelHandlerContext.class), objectCaptor.capture());
374+
b = (ByteBuf) objectCaptor.getAllValues().get(1);
375+
// If we were using the real grpcHandler, this would have been the HTTP/2 preface
376+
String preface = b.toString(UTF_8);
377+
b.release();
378+
assertEquals(golden, preface);
379+
380+
channel.close();
381+
}
382+
383+
private static ByteBuf bb(String s, Channel c) {
384+
return ByteBufUtil.writeUtf8(c.alloc(), s);
385+
}
275386
}

okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,20 @@ public ConnectionClientTransport newClientTransport(
311311
if (closed) {
312312
throw new IllegalStateException("The transport factory is closed.");
313313
}
314+
InetSocketAddress proxyAddress = null;
315+
String proxy = System.getenv("GRPC_PROXY_EXP");
316+
if (proxy != null) {
317+
String[] parts = proxy.split(":", 2);
318+
int port = 80;
319+
if (parts.length > 1) {
320+
port = Integer.parseInt(parts[1]);
321+
}
322+
proxyAddress = new InetSocketAddress(parts[0], port);
323+
}
314324
InetSocketAddress inetSocketAddr = (InetSocketAddress) addr;
315325
OkHttpClientTransport transport = new OkHttpClientTransport(inetSocketAddr, authority,
316-
userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize);
326+
userAgent, executor, socketFactory, Utils.convertSpec(connectionSpec), maxMessageSize,
327+
proxyAddress, null, null);
317328
if (enableKeepAlive) {
318329
transport.enableKeepAlive(true, keepAliveDelayNanos, keepAliveTimeoutNanos);
319330
}

0 commit comments

Comments
 (0)