Skip to content

Commit 8666e83

Browse files
committed
Prevent blocking all virtual serial ports on full PTS
1 parent ffcf4b4 commit 8666e83

File tree

1 file changed

+38
-8
lines changed

1 file changed

+38
-8
lines changed

src/lib.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1+
use bytes::{Bytes, Buf};
12
use camino::{Utf8Path, Utf8PathBuf};
23
use thiserror::Error;
3-
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
4+
use tokio::io::{AsyncRead, AsyncWrite};
45
#[cfg(unix)]
56
use tokio_serial::SerialPort;
67
use tokio_serial::SerialPortBuilderExt;
78
use tokio_serial::SerialStream;
89
use tokio_stream::{StreamExt, StreamMap};
910
use tokio_util::io::ReaderStream;
10-
use tracing::{error, info};
11+
use tracing::{error, warn, info};
1112

1213
use std::collections::HashMap;
1314
use std::fs;
15+
use std::io::ErrorKind;
16+
use std::pin::Pin;
17+
use std::task::Poll::Ready;
1418

1519
#[cfg(unix)]
1620
use std::os::unix;
@@ -108,15 +112,41 @@ where
108112
let bytes = result.map_err(Error::Read)?;
109113
info!(?src_id, ?dst_ids, ?bytes, "read");
110114
for dst_id in dst_ids {
111-
// This unwrap is OK as long as we validate all route IDs exist first
112-
// Route IDs are validated in Args::check_route_ids()
113-
let dst = sinks.get_mut(dst_id).unwrap();
114-
let mut buf = bytes.clone();
115-
dst.write_all_buf(&mut buf).await.map_err(Error::Write)?;
116-
info!(?dst_id, ?bytes, "wrote");
115+
if let Some(dst) = sinks.get_mut(dst_id) {
116+
let mut buf = bytes.clone();
117+
if let Err(e) = write_non_blocking(dst, &mut buf).await {
118+
if let Error::Write(io_err) = &e {
119+
if io_err.kind() == ErrorKind::WouldBlock {
120+
warn!(?dst_id, ?bytes, "discarded");
121+
} else {
122+
error!(?dst_id, ?e, "write error");
123+
}
124+
}
125+
} else {
126+
info!(?dst_id, ?bytes, "wrote");
127+
}
128+
}
117129
}
118130
}
119131
}
120132

121133
Ok(())
122134
}
135+
136+
async fn write_non_blocking<W: AsyncWrite + Unpin>(
137+
dst: &mut W,
138+
buf: &mut Bytes,
139+
) -> Result<()> {
140+
let waker = futures::task::noop_waker();
141+
let mut cx = futures::task::Context::from_waker(&waker);
142+
143+
let pinned_dst = Pin::new(dst);
144+
match pinned_dst.poll_write(&mut cx, buf) {
145+
Ready(Ok(bytes_written)) => {
146+
buf.advance(bytes_written);
147+
Ok(())
148+
}
149+
Ready(Err(e)) => Err(Error::Write(e)),
150+
_ => Err(Error::Write(std::io::Error::new(ErrorKind::WouldBlock, "Would block"))),
151+
}
152+
}

0 commit comments

Comments
 (0)