Skip to content

Commit 7167e4f

Browse files
committed
fix(rs-nats): close streams on drop
Signed-off-by: Roman Volosatovs <[email protected]>
1 parent 6666de0 commit 7167e4f

File tree

5 files changed

+87
-12
lines changed

5 files changed

+87
-12
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ version = "0.11.0"
66
authors.workspace = true
77
categories.workspace = true
88
edition.workspace = true
9+
homepage.workspace = true
910
license.workspace = true
1011
repository.workspace = true
1112

@@ -140,6 +141,6 @@ wrpc-cli = { version = "0.3", path = "./crates/cli", default-features = false }
140141
wrpc-introspect = { version = "0.3", default-features = false, path = "./crates/introspect" }
141142
wrpc-runtime-wasmtime = { version = "0.22", path = "./crates/runtime-wasmtime", default-features = false }
142143
wrpc-transport = { version = "0.26.8", path = "./crates/transport", default-features = false }
143-
wrpc-transport-nats = { version = "0.23.1", path = "./crates/transport-nats", default-features = false }
144+
wrpc-transport-nats = { version = "0.23.2", path = "./crates/transport-nats", default-features = false }
144145
wrpc-transport-quic = { version = "0.1.2", path = "./crates/transport-quic", default-features = false }
145146
wrpc-wasmtime-nats-cli = { version = "0.8", path = "./crates/wasmtime-nats-cli", default-features = false }

crates/transport-nats/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "wrpc-transport-nats"
3-
version = "0.23.1"
3+
version = "0.23.2"
44
description = "wRPC NATS transport"
55

66
authors.workspace = true

crates/transport-nats/src/lib.rs

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,10 @@ impl AsyncRead for Reader {
357357
match self.incoming.poll_next_unpin(cx) {
358358
Poll::Ready(Some(Message { mut payload, .. })) => {
359359
trace!(?payload, "received message");
360+
if payload.is_empty() {
361+
trace!("received stream shutdown message");
362+
return Poll::Ready(Ok(()));
363+
}
360364
if payload.len() > cap {
361365
trace!(len = payload.len(), cap, "partially reading the message");
362366
buf.put_slice(&payload.split_to(cap));
@@ -380,11 +384,16 @@ impl AsyncRead for Reader {
380384
pub struct SubjectWriter {
381385
nats: async_nats::Client,
382386
tx: Subject,
387+
shutdown: bool,
383388
}
384389

385390
impl SubjectWriter {
386391
fn new(nats: async_nats::Client, tx: Subject) -> Self {
387-
Self { nats, tx }
392+
Self {
393+
nats,
394+
tx,
395+
shutdown: false,
396+
}
388397
}
389398
}
390399

@@ -395,6 +404,7 @@ impl wrpc_transport::Index<Self> for SubjectWriter {
395404
Ok(Self {
396405
nats: self.nats.clone(),
397406
tx,
407+
shutdown: false,
398408
})
399409
}
400410
}
@@ -450,18 +460,34 @@ impl AsyncWrite for SubjectWriter {
450460

451461
#[instrument(level = "trace", skip_all, ret, fields(subject = self.tx.as_str()))]
452462
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
453-
trace!("writing empty buffer to shut down stream");
463+
trace!("writing stream shutdown message");
454464
ready!(self.as_mut().poll_write(cx, &[]))?;
465+
self.shutdown = true;
455466
Poll::Ready(Ok(()))
456467
}
457468
}
458469

470+
impl Drop for SubjectWriter {
471+
fn drop(&mut self) {
472+
if !self.shutdown {
473+
let nats = self.nats.clone();
474+
let subject = mem::replace(&mut self.tx, Subject::from_static(""));
475+
tokio::spawn(async move {
476+
trace!("writing stream shutdown message");
477+
if let Err(err) = nats.publish(subject, Bytes::default()).await {
478+
warn!(?err, "failed to publish stream shutdown message")
479+
}
480+
});
481+
}
482+
}
483+
}
484+
459485
#[derive(Debug, Default)]
460486
pub enum RootParamWriter {
461487
#[default]
462488
Corrupted,
463489
Handshaking {
464-
tx: SubjectWriter,
490+
nats: async_nats::Client,
465491
sub: Subscriber,
466492
indexed: std::sync::Mutex<Vec<(Vec<usize>, oneshot::Sender<SubjectWriter>)>>,
467493
buffer: Bytes,
@@ -474,9 +500,9 @@ pub enum RootParamWriter {
474500
}
475501

476502
impl RootParamWriter {
477-
fn new(tx: SubjectWriter, sub: Subscriber, buffer: Bytes) -> Self {
503+
fn new(nats: async_nats::Client, sub: Subscriber, buffer: Bytes) -> Self {
478504
Self::Handshaking {
479-
tx,
505+
nats,
480506
sub,
481507
indexed: std::sync::Mutex::default(),
482508
buffer,
@@ -520,7 +546,7 @@ impl RootParamWriter {
520546
reply: Some(tx), ..
521547
})) => {
522548
let Self::Handshaking {
523-
tx: SubjectWriter { nats, .. },
549+
nats,
524550
indexed,
525551
buffer,
526552
..
@@ -899,7 +925,7 @@ impl wrpc_transport::Invoke for Client {
899925
.context("failed to send handshake")?;
900926
Ok((
901927
ParamWriter::Root(RootParamWriter::new(
902-
SubjectWriter::new((*self.nats).clone(), param_tx),
928+
(*self.nats).clone(),
903929
handshake_rx,
904930
params,
905931
)),

tests/rust.rs

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,13 +598,61 @@ where
598598
)
599599
.await
600600
.context("failed to serve `test.async`")?;
601+
let reset_inv = srv
602+
.serve_values::<(String,), (String,)>("test", "reset", [Box::default(); 0])
603+
.await
604+
.context("failed to serve `test.reset`")?;
601605
let sync_inv = srv
602606
.serve_values("test", "sync", [Box::default(); 0])
603607
.await
604608
.context("failed to serve `test.sync`")?;
609+
605610
let mut async_inv = pin!(async_inv);
611+
let mut reset_inv = pin!(reset_inv);
606612
let mut sync_inv = pin!(sync_inv);
607613

614+
join!(
615+
async {
616+
info!("receiving `test.reset` parameters");
617+
reset_inv
618+
.try_next()
619+
.await
620+
.expect("failed to accept invocation")
621+
.expect("unexpected end of stream");
622+
info!("receiving `test.reset` parameters");
623+
reset_inv
624+
.try_next()
625+
.await
626+
.expect("failed to accept invocation")
627+
.expect("unexpected end of stream");
628+
anyhow::Ok(())
629+
}
630+
.instrument(info_span!("server")),
631+
async {
632+
info!("invoking `test.reset`");
633+
clt.invoke_values_blocking::<_, _, (String,)>(
634+
C::default(),
635+
"test",
636+
"reset",
637+
("arg",),
638+
&[[]; 0],
639+
)
640+
.await
641+
.expect_err("`test.reset` should have failed");
642+
info!("invoking `test.reset`");
643+
clt.invoke_values_blocking::<_, _, (String,)>(
644+
C::default(),
645+
"test",
646+
"reset",
647+
("arg",),
648+
&[[]; 0],
649+
)
650+
.await
651+
.expect_err("`test.reset` should have failed");
652+
}
653+
.instrument(info_span!("client")),
654+
);
655+
608656
join!(
609657
async {
610658
info!("receiving `test.sync` parameters");
@@ -730,7 +778,6 @@ where
730778
assert_eq!(m, "test");
731779
assert_eq!(n, [[b"foo"]]);
732780
assert_eq!(o, [Some(vec![Ok(Some(String::from("bar")))])]);
733-
info!("finishing `test.sync` session");
734781
}
735782
.instrument(info_span!("client")),
736783
);
@@ -867,6 +914,7 @@ where
867914
}
868915
.instrument(info_span!("client")),
869916
);
917+
870918
Ok(())
871919
}
872920

@@ -942,7 +990,7 @@ async fn rust_dynamic_quic() -> anyhow::Result<()> {
942990
use core::pin::pin;
943991

944992
common::with_quic(
945-
&["sync.test", "async.test"],
993+
&["sync.test", "async.test", "reset.test"],
946994
|port, clt_ep, srv_ep| async move {
947995
let clt = wrpc_transport_quic::Client::new(clt_ep, (Ipv6Addr::LOCALHOST, port));
948996
let srv = wrpc_transport_quic::Server::default();

0 commit comments

Comments
 (0)