1212// See the License for the specific language governing permissions and
1313// limitations under the License.
1414
15+ use async_trait:: async_trait;
1516use log:: debug;
1617use pingora_error:: { Context , Error , ErrorType :: * , OrErr , Result } ;
1718use rand:: seq:: SliceRandom ;
@@ -26,6 +27,12 @@ use crate::protocols::l4::stream::Stream;
2627use crate :: protocols:: { GetSocketDigest , SocketDigest } ;
2728use crate :: upstreams:: peer:: Peer ;
2829
30+ /// The interface to establish a L4 connection
31+ #[ async_trait]
32+ pub trait Connect : std:: fmt:: Debug {
33+ async fn connect ( & self , addr : & SocketAddr ) -> Result < Stream > ;
34+ }
35+
2936/// Establish a connection (l4) to the given peer using its settings and an optional bind address.
3037pub async fn connect < P > ( peer : & P , bind_to : Option < InetSocketAddr > ) -> Result < Stream >
3138where
@@ -37,72 +44,78 @@ where
3744 . err_context ( || format ! ( "Fail to establish CONNECT proxy: {}" , peer) ) ;
3845 }
3946 let peer_addr = peer. address ( ) ;
40- let mut stream: Stream = match peer_addr {
41- SocketAddr :: Inet ( addr) => {
42- let connect_future = tcp_connect ( addr, bind_to. as_ref ( ) , |socket| {
43- if peer. tcp_fast_open ( ) {
44- set_tcp_fastopen_connect ( socket. as_raw_fd ( ) ) ?;
45- }
46- if let Some ( recv_buf) = peer. tcp_recv_buf ( ) {
47- debug ! ( "Setting recv buf size" ) ;
48- set_recv_buf ( socket. as_raw_fd ( ) , recv_buf) ?;
49- }
50- if let Some ( dscp) = peer. dscp ( ) {
51- debug ! ( "Setting dscp" ) ;
52- set_dscp ( socket. as_raw_fd ( ) , dscp) ?;
53- }
54- Ok ( ( ) )
55- } ) ;
56- let conn_res = match peer. connection_timeout ( ) {
57- Some ( t) => pingora_timeout:: timeout ( t, connect_future)
58- . await
59- . explain_err ( ConnectTimedout , |_| {
60- format ! ( "timeout {t:?} connecting to server {peer}" )
61- } ) ?,
62- None => connect_future. await ,
63- } ;
64- match conn_res {
65- Ok ( socket) => {
66- debug ! ( "connected to new server: {}" , peer. address( ) ) ;
67- Ok ( socket. into ( ) )
68- }
69- Err ( e) => {
70- let c = format ! ( "Fail to connect to {peer}" ) ;
71- match e. etype ( ) {
72- SocketError | BindError => Error :: e_because ( InternalError , c, e) ,
73- _ => Err ( e. more_context ( c) ) ,
47+ let mut stream: Stream =
48+ if let Some ( custom_l4) = peer. get_peer_options ( ) . and_then ( |o| o. custom_l4 . as_ref ( ) ) {
49+ custom_l4. connect ( peer_addr) . await ?
50+ } else {
51+ match peer_addr {
52+ SocketAddr :: Inet ( addr) => {
53+ let connect_future = tcp_connect ( addr, bind_to. as_ref ( ) , |socket| {
54+ if peer. tcp_fast_open ( ) {
55+ set_tcp_fastopen_connect ( socket. as_raw_fd ( ) ) ?;
56+ }
57+ if let Some ( recv_buf) = peer. tcp_recv_buf ( ) {
58+ debug ! ( "Setting recv buf size" ) ;
59+ set_recv_buf ( socket. as_raw_fd ( ) , recv_buf) ?;
60+ }
61+ if let Some ( dscp) = peer. dscp ( ) {
62+ debug ! ( "Setting dscp" ) ;
63+ set_dscp ( socket. as_raw_fd ( ) , dscp) ?;
64+ }
65+ Ok ( ( ) )
66+ } ) ;
67+ let conn_res = match peer. connection_timeout ( ) {
68+ Some ( t) => pingora_timeout:: timeout ( t, connect_future)
69+ . await
70+ . explain_err ( ConnectTimedout , |_| {
71+ format ! ( "timeout {t:?} connecting to server {peer}" )
72+ } ) ?,
73+ None => connect_future. await ,
74+ } ;
75+ match conn_res {
76+ Ok ( socket) => {
77+ debug ! ( "connected to new server: {}" , peer. address( ) ) ;
78+ Ok ( socket. into ( ) )
79+ }
80+ Err ( e) => {
81+ let c = format ! ( "Fail to connect to {peer}" ) ;
82+ match e. etype ( ) {
83+ SocketError | BindError => Error :: e_because ( InternalError , c, e) ,
84+ _ => Err ( e. more_context ( c) ) ,
85+ }
86+ }
7487 }
7588 }
76- }
77- }
78- SocketAddr :: Unix ( addr ) => {
79- let connect_future = connect_uds (
80- addr . as_pathname ( )
81- . expect ( "non-pathname unix sockets not supported as peer" ) ,
82- ) ;
83- let conn_res = match peer . connection_timeout ( ) {
84- Some ( t ) => pingora_timeout :: timeout ( t , connect_future )
85- . await
86- . explain_err ( ConnectTimedout , |_| {
87- format ! ( "timeout {t:?} connecting to server {peer}" )
88- } ) ? ,
89- None => connect_future . await ,
90- } ;
91- match conn_res {
92- Ok ( socket) => {
93- debug ! ( "connected to new server: {}" , peer . address ( ) ) ;
94- Ok ( socket . into ( ) )
95- }
96- Err ( e ) => {
97- let c = format ! ( "Fail to connect to {peer}" ) ;
98- match e . etype ( ) {
99- SocketError | BindError => Error :: e_because ( InternalError , c , e ) ,
100- _ => Err ( e . more_context ( c ) ) ,
89+ SocketAddr :: Unix ( addr ) => {
90+ let connect_future = connect_uds (
91+ addr . as_pathname ( )
92+ . expect ( "non-pathname unix sockets not supported as peer" ) ,
93+ ) ;
94+ let conn_res = match peer. connection_timeout ( ) {
95+ Some ( t ) => pingora_timeout :: timeout ( t , connect_future )
96+ . await
97+ . explain_err ( ConnectTimedout , |_| {
98+ format ! ( "timeout {t:?} connecting to server {peer}" )
99+ } ) ? ,
100+ None => connect_future . await ,
101+ } ;
102+ match conn_res {
103+ Ok ( socket ) => {
104+ debug ! ( "connected to new server: {}" , peer . address ( ) ) ;
105+ Ok ( socket. into ( ) )
106+ }
107+ Err ( e ) => {
108+ let c = format ! ( "Fail to connect to {peer}" ) ;
109+ match e . etype ( ) {
110+ SocketError | BindError => Error :: e_because ( InternalError , c , e ) ,
111+ _ => Err ( e . more_context ( c ) ) ,
112+ }
113+ }
101114 }
102115 }
103- }
104- }
105- } ? ;
116+ } ?
117+ } ;
118+
106119 let tracer = peer. get_tracer ( ) ;
107120 if let Some ( t) = tracer {
108121 t. 0 . on_connected ( ) ;
@@ -249,6 +262,29 @@ mod tests {
249262 assert_eq ! ( new_session. unwrap_err( ) . etype( ) , & ConnectTimedout )
250263 }
251264
265+ #[ tokio:: test]
266+ async fn test_custom_connect ( ) {
267+ #[ derive( Debug ) ]
268+ struct MyL4 ;
269+ #[ async_trait]
270+ impl Connect for MyL4 {
271+ async fn connect ( & self , _addr : & SocketAddr ) -> Result < Stream > {
272+ tokio:: net:: TcpStream :: connect ( "1.1.1.1:80" )
273+ . await
274+ . map ( |s| s. into ( ) )
275+ . or_fail ( )
276+ }
277+ }
278+ // :79 shouldn't be able to be connected to
279+ let mut peer = BasicPeer :: new ( "1.1.1.1:79" ) ;
280+ peer. options . custom_l4 = Some ( std:: sync:: Arc :: new ( MyL4 { } ) ) ;
281+
282+ let new_session = connect ( & peer, None ) . await ;
283+
284+ // but MyL4 connects to :80 instead
285+ assert ! ( new_session. is_ok( ) ) ;
286+ }
287+
252288 #[ tokio:: test]
253289 async fn test_connect_proxy_fail ( ) {
254290 let mut peer = HttpPeer :: new ( "1.1.1.1:80" . to_string ( ) , false , "" . to_string ( ) ) ;
0 commit comments