Skip to content

fix(rs-nats): close streams on drop #341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
[package]
description = "WebAssembly component-native RPC framework based on WIT"
name = "wrpc"
version = "0.11.0"
version = "0.11.1"

authors.workspace = true
categories.workspace = true
edition.workspace = true
homepage.workspace = true
license.workspace = true
repository.workspace = true

Expand Down Expand Up @@ -140,6 +141,6 @@ wrpc-cli = { version = "0.3", path = "./crates/cli", default-features = false }
wrpc-introspect = { version = "0.3", default-features = false, path = "./crates/introspect" }
wrpc-runtime-wasmtime = { version = "0.22", path = "./crates/runtime-wasmtime", default-features = false }
wrpc-transport = { version = "0.26.8", path = "./crates/transport", default-features = false }
wrpc-transport-nats = { version = "0.23.1", path = "./crates/transport-nats", default-features = false }
wrpc-transport-nats = { version = "0.23.2", path = "./crates/transport-nats", default-features = false }
wrpc-transport-quic = { version = "0.1.2", path = "./crates/transport-quic", default-features = false }
wrpc-wasmtime-nats-cli = { version = "0.8", path = "./crates/wasmtime-nats-cli", default-features = false }
2 changes: 1 addition & 1 deletion crates/transport-nats/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "wrpc-transport-nats"
version = "0.23.1"
version = "0.23.2"
description = "wRPC NATS transport"

authors.workspace = true
Expand Down
53 changes: 45 additions & 8 deletions crates/transport-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use futures::{Stream, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::sync::oneshot;
use tokio::try_join;
use tracing::{debug, instrument, trace, warn};
use tracing::{debug, error, instrument, trace, warn};
use wrpc_transport::Index as _;

pub const PROTOCOL: &str = "wrpc.0.0.1";
Expand Down Expand Up @@ -357,6 +357,10 @@ impl AsyncRead for Reader {
match self.incoming.poll_next_unpin(cx) {
Poll::Ready(Some(Message { mut payload, .. })) => {
trace!(?payload, "received message");
if payload.is_empty() {
trace!("received stream shutdown message");
return Poll::Ready(Ok(()));
}
if payload.len() > cap {
trace!(len = payload.len(), cap, "partially reading the message");
buf.put_slice(&payload.split_to(cap));
Expand All @@ -380,11 +384,16 @@ impl AsyncRead for Reader {
pub struct SubjectWriter {
nats: async_nats::Client,
tx: Subject,
shutdown: bool,
}

impl SubjectWriter {
fn new(nats: async_nats::Client, tx: Subject) -> Self {
Self { nats, tx }
Self {
nats,
tx,
shutdown: false,
}
}
}

Expand All @@ -395,6 +404,7 @@ impl wrpc_transport::Index<Self> for SubjectWriter {
Ok(Self {
nats: self.nats.clone(),
tx,
shutdown: false,
})
}
}
Expand Down Expand Up @@ -450,18 +460,45 @@ impl AsyncWrite for SubjectWriter {

#[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
trace!("writing empty buffer to shut down stream");
trace!("writing stream shutdown message");
ready!(self.as_mut().poll_write(cx, &[]))?;
self.shutdown = true;
Poll::Ready(Ok(()))
}
}

impl Drop for SubjectWriter {
fn drop(&mut self) {
if !self.shutdown {
let nats = self.nats.clone();
let subject = mem::replace(&mut self.tx, Subject::from_static(""));
let fut = async move {
trace!("writing stream shutdown message");
if let Err(err) = nats.publish(subject, Bytes::default()).await {
warn!(?err, "failed to publish stream shutdown message")
}
};
match tokio::runtime::Handle::try_current() {
Ok(rt) => {
rt.spawn(fut);
}
Err(_) => match tokio::runtime::Runtime::new() {
Ok(rt) => {
rt.spawn(fut);
}
Err(err) => error!(?err, "failed to create a new Tokio runtime"),
},
}
}
}
}

#[derive(Debug, Default)]
pub enum RootParamWriter {
#[default]
Corrupted,
Handshaking {
tx: SubjectWriter,
nats: async_nats::Client,
sub: Subscriber,
indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
buffer: Bytes,
Expand All @@ -474,9 +511,9 @@ pub enum RootParamWriter {
}

impl RootParamWriter {
fn new(tx: SubjectWriter, sub: Subscriber, buffer: Bytes) -> Self {
fn new(nats: async_nats::Client, sub: Subscriber, buffer: Bytes) -> Self {
Self::Handshaking {
tx,
nats,
sub,
indexed: std::sync::Mutex::default(),
buffer,
Expand Down Expand Up @@ -520,7 +557,7 @@ impl RootParamWriter {
reply: Some(tx), ..
})) => {
let Self::Handshaking {
tx: SubjectWriter { nats, .. },
nats,
indexed,
buffer,
..
Expand Down Expand Up @@ -899,7 +936,7 @@ impl wrpc_transport::Invoke for Client {
.context("failed to send handshake")?;
Ok((
ParamWriter::Root(RootParamWriter::new(
SubjectWriter::new((*self.nats).clone(), param_tx),
(*self.nats).clone(),
handshake_rx,
params,
)),
Expand Down
71 changes: 68 additions & 3 deletions tests/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use core::str;
use core::time::Duration;

use std::sync::Arc;
use std::thread;

use anyhow::Context;
use bytes::Bytes;
Expand Down Expand Up @@ -576,7 +577,7 @@ where
#[instrument(skip_all, ret)]
async fn assert_dynamic<C, I, S>(clt: Arc<I>, srv: Arc<S>) -> anyhow::Result<()>
where
C: Send + Sync + Default,
C: Send + Sync + Default + 'static,
I: wrpc::Invoke<Context = C>,
S: wrpc::Serve<Context = C>,
{
Expand All @@ -598,13 +599,77 @@ where
)
.await
.context("failed to serve `test.async`")?;
let reset_inv = srv
.serve_values::<(String,), (String,)>("test", "reset", [Box::default(); 0])
.await
.context("failed to serve `test.reset`")?;
let sync_inv = srv
.serve_values("test", "sync", [Box::default(); 0])
.await
.context("failed to serve `test.sync`")?;

let mut async_inv = pin!(async_inv);
let mut reset_inv = pin!(reset_inv);
let mut sync_inv = pin!(sync_inv);

join!(
async {
info!("receiving `test.reset` parameters");
reset_inv
.try_next()
.await
.expect("failed to accept invocation")
.expect("unexpected end of stream");
info!("receiving `test.reset` parameters");
reset_inv
.try_next()
.await
.expect("failed to accept invocation")
.expect("unexpected end of stream");
let inv = reset_inv
.try_next()
.await
.expect("failed to accept invocation")
.expect("unexpected end of stream");
thread::spawn(|| inv);
anyhow::Ok(())
}
.instrument(info_span!("server")),
async {
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
"test",
"reset",
("arg",),
&[[]; 0],
)
.await
.expect_err("`test.reset` should have failed");
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
"test",
"reset",
("arg",),
&[[]; 0],
)
.await
.expect_err("`test.reset` should have failed");
info!("invoking `test.reset`");
clt.invoke_values_blocking::<_, _, (String,)>(
C::default(),
"test",
"reset",
("arg",),
&[[]; 0],
)
.await
.expect_err("`test.reset` should have failed");
}
.instrument(info_span!("client")),
);

join!(
async {
info!("receiving `test.sync` parameters");
Expand Down Expand Up @@ -730,7 +795,6 @@ where
assert_eq!(m, "test");
assert_eq!(n, [[b"foo"]]);
assert_eq!(o, [Some(vec![Ok(Some(String::from("bar")))])]);
info!("finishing `test.sync` session");
}
.instrument(info_span!("client")),
);
Expand Down Expand Up @@ -867,6 +931,7 @@ where
}
.instrument(info_span!("client")),
);

Ok(())
}

Expand Down Expand Up @@ -942,7 +1007,7 @@ async fn rust_dynamic_quic() -> anyhow::Result<()> {
use core::pin::pin;

common::with_quic(
&["sync.test", "async.test"],
&["sync.test", "async.test", "reset.test"],
|port, clt_ep, srv_ep| async move {
let clt = wrpc_transport_quic::Client::new(clt_ep, (Ipv6Addr::LOCALHOST, port));
let srv = wrpc_transport_quic::Server::default();
Expand Down
Loading