Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ scoped-tls = "1.0.0"
slab = "0.4.2"
libc = "0.2.80"
io-uring = { version = "0.5.0", features = [ "unstable" ] }
os_socketaddr = "0.2.0"
socket2 = { version = "0.4.4", features = [ "all"] }
bytes = { version = "1.0", optional = true }

[dev-dependencies]
Expand Down
32 changes: 32 additions & 0 deletions examples/unix_listener.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use std::env;

use tokio_uring::net::UnixListener;

fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: String = args[1].clone();

tokio_uring::start(async {
let listener = UnixListener::bind(&socket_addr).unwrap();

loop {
let stream = listener.accept().await.unwrap();
let socket_addr = socket_addr.clone();
tokio_uring::spawn(async move {
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written to {}: {}", &socket_addr, result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read from {}: {:?}", &socket_addr, &buf[..read]);
});
}
});
}
25 changes: 25 additions & 0 deletions examples/unix_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use std::env;

use tokio_uring::net::UnixStream;

fn main() {
let args: Vec<_> = env::args().collect();

if args.len() <= 1 {
panic!("no addr specified");
}

let socket_addr: &String = &args[1];

tokio_uring::start(async {
let stream = UnixStream::connect(socket_addr).await.unwrap();
let buf = vec![1u8; 128];

let (result, buf) = stream.write(buf).await;
println!("written: {}", result.unwrap());

let (result, buf) = stream.read(buf).await;
let read = result.unwrap();
println!("read: {:?}", &buf[..read]);
});
}
16 changes: 7 additions & 9 deletions src/driver/connect.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,28 @@
use crate::driver::{Op, SharedFd};
use os_socketaddr::OsSocketAddr;
use std::{io, net::SocketAddr};
use socket2::SockAddr;
use std::io;

/// Open a file
pub(crate) struct Connect {
fd: SharedFd,
os_socket_addr: OsSocketAddr,
socket_addr: SockAddr,
}

impl Op<Connect> {
/// Submit a request to connect.
pub(crate) fn connect(fd: &SharedFd, socket_addr: SocketAddr) -> io::Result<Op<Connect>> {
pub(crate) fn connect(fd: &SharedFd, socket_addr: SockAddr) -> io::Result<Op<Connect>> {
use io_uring::{opcode, types};

let os_socket_addr = OsSocketAddr::from(socket_addr);

Op::submit_with(
Connect {
fd: fd.clone(),
os_socket_addr,
socket_addr,
},
|connect| {
opcode::Connect::new(
types::Fd(connect.fd.raw_fd()),
connect.os_socket_addr.as_ptr(),
connect.os_socket_addr.len(),
connect.socket_addr.as_ptr(),
connect.socket_addr.len(),
)
.build()
},
Expand Down
2 changes: 1 addition & 1 deletion src/driver/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ where
Poll::Ready(Completion {
data: me.data.take().expect("unexpected operation state"),
result,
flags: flags,
flags,
})
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/driver/recv_from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::{
driver::{Op, SharedFd},
BufResult,
};
use os_socketaddr::OsSocketAddr;
use socket2::SockAddr;
use std::{
io::IoSliceMut,
task::{Context, Poll},
Expand All @@ -15,7 +15,7 @@ pub(crate) struct RecvFrom<T> {
fd: SharedFd,
pub(crate) buf: T,
io_slices: Vec<IoSliceMut<'static>>,
pub(crate) os_socket_addr: Box<OsSocketAddr>,
pub(crate) socket_addr: Box<SockAddr>,
pub(crate) msghdr: Box<libc::msghdr>,
}

Expand All @@ -27,20 +27,20 @@ impl<T: IoBufMut> Op<RecvFrom<T>> {
std::slice::from_raw_parts_mut(buf.stable_mut_ptr(), buf.bytes_total())
})];

let mut os_socket_addr = Box::new(OsSocketAddr::new());
let socket_addr = Box::new(unsafe { SockAddr::init(|_, _| Ok(()))?.1 });

let mut msghdr: Box<libc::msghdr> = Box::new(unsafe { std::mem::zeroed() });
msghdr.msg_iov = io_slices.as_mut_ptr().cast();
msghdr.msg_iovlen = io_slices.len() as _;
msghdr.msg_name = os_socket_addr.as_mut_ptr() as *mut libc::c_void;
msghdr.msg_namelen = os_socket_addr.capacity();
msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
msghdr.msg_namelen = socket_addr.len();

Op::submit_with(
RecvFrom {
fd: fd.clone(),
buf,
io_slices,
os_socket_addr,
socket_addr,
msghdr,
},
|recv_from| {
Expand Down Expand Up @@ -74,7 +74,7 @@ impl<T: IoBufMut> Op<RecvFrom<T>> {
let result = match complete.result {
Ok(v) => {
let v = v as usize;
let socket_addr: Option<SocketAddr> = (*complete.data.os_socket_addr).into();
let socket_addr: Option<SocketAddr> = (*complete.data.socket_addr).as_socket();
// If the operation was successful, advance the initialized cursor.
// Safety: the kernel wrote `v` bytes to the buffer.
unsafe {
Expand Down
12 changes: 6 additions & 6 deletions src/driver/send_to.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::buf::IoBuf;
use crate::driver::{Op, SharedFd};
use crate::BufResult;
use os_socketaddr::OsSocketAddr;
use socket2::SockAddr;
use std::io::IoSlice;
use std::task::{Context, Poll};
use std::{boxed::Box, io, net::SocketAddr};
Expand All @@ -13,7 +13,7 @@ pub(crate) struct SendTo<T> {
#[allow(dead_code)]
io_slices: Vec<IoSlice<'static>>,
#[allow(dead_code)]
os_socket_addr: Box<OsSocketAddr>,
socket_addr: Box<SockAddr>,
pub(crate) msghdr: Box<libc::msghdr>,
}

Expand All @@ -29,20 +29,20 @@ impl<T: IoBuf> Op<SendTo<T>> {
std::slice::from_raw_parts(buf.stable_ptr(), buf.bytes_init())
})];

let mut os_socket_addr = Box::new(OsSocketAddr::from(socket_addr));
let socket_addr = Box::new(SockAddr::from(socket_addr));

let mut msghdr: Box<libc::msghdr> = Box::new(unsafe { std::mem::zeroed() });
msghdr.msg_iov = io_slices.as_ptr() as *mut _;
msghdr.msg_iovlen = io_slices.len() as _;
msghdr.msg_name = os_socket_addr.as_mut_ptr() as *mut libc::c_void;
msghdr.msg_namelen = os_socket_addr.len();
msghdr.msg_name = socket_addr.as_ptr() as *mut libc::c_void;
msghdr.msg_namelen = socket_addr.len();

Op::submit_with(
SendTo {
fd: fd.clone(),
buf,
io_slices,
os_socket_addr,
socket_addr,
msghdr,
},
|send_to| {
Expand Down
76 changes: 56 additions & 20 deletions src/driver/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use crate::{
buf::{IoBuf, IoBufMut},
driver::{Op, SharedFd},
};
use os_socketaddr::OsSocketAddr;
use std::{
io,
net::SocketAddr,
os::unix::io::{AsRawFd, RawFd},
os::unix::io::{AsRawFd, IntoRawFd, RawFd},
path::Path,
};

#[derive(Clone)]
Expand All @@ -26,7 +26,15 @@ impl Socket {
pub(crate) fn new(socket_addr: SocketAddr, socket_type: libc::c_int) -> io::Result<Socket> {
let socket_type = socket_type | libc::SOCK_CLOEXEC;
let domain = get_domain(socket_addr);
let fd = syscall!(socket(domain, socket_type, 0))?;
let fd = socket2::Socket::new(domain.into(), socket_type.into(), None)?.into_raw_fd();
let fd = SharedFd::new(fd);
Ok(Socket { fd })
}

pub(crate) fn new_unix(socket_type: libc::c_int) -> io::Result<Socket> {
let socket_type = socket_type | libc::SOCK_CLOEXEC;
let domain = libc::AF_UNIX;
let fd = socket2::Socket::new(domain.into(), socket_type.into(), None)?.into_raw_fd();
let fd = SharedFd::new(fd);
Ok(Socket { fd })
}
Expand Down Expand Up @@ -58,38 +66,66 @@ impl Socket {
op.recv().await
}

pub(crate) async fn accept(&self) -> io::Result<(Socket, SocketAddr)> {
pub(crate) async fn accept(&self) -> io::Result<(Socket, Option<SocketAddr>)> {
let op = Op::accept(&self.fd)?;
let completion = op.await;
let fd = completion.result?;
let fd = SharedFd::new(fd as i32);
let data = completion.data;
let socket = Socket { fd };
let os_socket_addr = unsafe {
OsSocketAddr::from_raw_parts(
&completion.data.socketaddr.0 as *const _ as _,
completion.data.socketaddr.1 as usize,
)
let (_, addr) = unsafe {
socket2::SockAddr::init(move |addr_storage, len| {
*addr_storage = data.socketaddr.0.to_owned();
*len = data.socketaddr.1;
Ok(())
})?
};
let socket_addr = os_socket_addr.into_addr().unwrap();
Ok((socket, socket_addr))
Ok((socket, addr.as_socket()))
}

pub(crate) async fn connect(&self, socket_addr: SocketAddr) -> io::Result<()> {
pub(crate) async fn connect(&self, socket_addr: socket2::SockAddr) -> io::Result<()> {
let op = Op::connect(&self.fd, socket_addr)?;
let completion = op.await;
completion.result?;
Ok(())
}

pub(crate) fn bind(socket_addr: SocketAddr, socket_type: libc::c_int) -> io::Result<Socket> {
let socket = Socket::new(socket_addr, socket_type)?;
let os_socket_addr = OsSocketAddr::from(socket_addr);
syscall!(bind(
socket.as_raw_fd(),
os_socket_addr.as_ptr(),
os_socket_addr.len()
))?;
Ok(socket)
Self::bind_internal(
socket_addr.into(),
get_domain(socket_addr).into(),
socket_type.into(),
)
}

pub(crate) fn bind_unix<P: AsRef<Path>>(
path: P,
socket_type: libc::c_int,
) -> io::Result<Socket> {
let addr = socket2::SockAddr::unix(path.as_ref())?;
Self::bind_internal(addr, libc::AF_UNIX.into(), socket_type.into())
}

fn bind_internal(
socket_addr: socket2::SockAddr,
domain: socket2::Domain,
socket_type: socket2::Type,
) -> io::Result<Socket> {
let sys_listener = socket2::Socket::new(domain, socket_type, None)?;
let addr = socket2::SockAddr::from(socket_addr);

sys_listener.set_reuse_port(true)?;
sys_listener.set_reuse_address(true)?;

// TODO: config for buffer sizes
// sys_listener.set_send_buffer_size(send_buf_size)?;
// sys_listener.set_recv_buffer_size(recv_buf_size)?;

sys_listener.bind(&addr)?;

let fd = SharedFd::new(sys_listener.into_raw_fd());

Ok(Self { fd })
}

pub(crate) fn listen(&self, backlog: libc::c_int) -> io::Result<()> {
Expand Down
2 changes: 2 additions & 0 deletions src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

mod tcp;
mod udp;
mod unix;

pub use tcp::{TcpListener, TcpStream};
pub use udp::UdpSocket;
pub use unix::{UnixListener, UnixStream};
13 changes: 9 additions & 4 deletions src/net/tcp/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ impl TcpListener {
/// The returned listener is ready for accepting connections.
///
/// Binding with a port number of 0 will request that the OS assigns a port
/// to this listener. The port allocated can be queried via the `local_addr`
/// to this listener.
///
/// In the future, the port allocated can be queried via a (blocking) `local_addr`
/// method.
pub fn bind(socket_addr: SocketAddr) -> io::Result<TcpListener> {
let socket = Socket::bind(socket_addr, libc::SOCK_STREAM)?;
pub fn bind(addr: SocketAddr) -> io::Result<Self> {
let socket = Socket::bind(addr, libc::SOCK_STREAM)?;
socket.listen(1024)?;
Ok(TcpListener { inner: socket })
return Ok(TcpListener { inner: socket });
}

/// Accepts a new incoming connection from this listener.
Expand All @@ -59,6 +61,9 @@ impl TcpListener {
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
let (socket, socket_addr) = self.inner.accept().await?;
let stream = TcpStream { inner: socket };
let socket_addr = socket_addr.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "Could not get socket IP address")
})?;
Ok((stream, socket_addr))
}
}
Loading