
Commit 945bb78

highker authored and facebook-github-bot committed
sync flush logs upon mesh stop (#885)
Summary:
Pull Request resolved: #885

Force sync flush upon mesh stop.

Differential Revision: D80310284
1 parent 76a7fe8 commit 945bb78
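
At a glance, the commit registers the log flush as an on-stop callback on the proc mesh, so stopping the mesh synchronously flushes logs (bounded by a timeout) before teardown. Below is a minimal, self-contained sketch of that pattern using only tokio; the names `Stoppable`, `OnStop`, and `register_onstop` are illustrative stand-ins, not APIs from this repo.

use std::future::Future;
use std::pin::Pin;
use std::time::Duration;

// Boxed async callback, analogous to the OnStopCallback alias added in this commit.
type OnStop = Box<dyn FnOnce() -> Pin<Box<dyn Future<Output = ()> + Send>> + Send>;

struct Stoppable {
    on_stop: Vec<OnStop>,
}

impl Stoppable {
    // Accept any async closure and box it for later invocation.
    fn register_onstop<F, Fut>(&mut self, f: F)
    where
        F: FnOnce() -> Fut + Send + 'static,
        Fut: Future<Output = ()> + Send + 'static,
    {
        self.on_stop.push(Box::new(move || Box::pin(f())));
    }

    // Drain callbacks in registration order, bounding each with a timeout
    // so a stuck flush cannot hang shutdown.
    async fn stop(mut self) {
        for cb in self.on_stop.drain(..) {
            if tokio::time::timeout(Duration::from_secs(30), cb()).await.is_err() {
                eprintln!("on-stop callback timed out");
            }
        }
    }
}

#[tokio::main]
async fn main() {
    let mut mesh = Stoppable { on_stop: Vec::new() };
    mesh.register_onstop(|| async { println!("flushing logs before stop") });
    mesh.stop().await;
}

The real change threads this pattern through `TrackedProcMesh::register_onstop_callback` and `PyProcMesh::stop_mesh`, as the diffs below show.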

File tree

4 files changed: +271 −25 lines


monarch_extension/src/logging.rs

Lines changed: 39 additions & 1 deletion
@@ -8,7 +8,11 @@

 #![allow(unsafe_op_in_unsafe_fn)]

+use std::time::Duration;
+
 use hyperactor::ActorHandle;
+use hyperactor::clock::Clock;
+use hyperactor::clock::RealClock;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::actor_mesh::ActorMesh;
 use hyperactor_mesh::logging::LogClientActor;
@@ -27,6 +31,8 @@ use pyo3::Bound;
 use pyo3::prelude::*;
 use pyo3::types::PyModule;

+static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);
+
 #[pyclass(
     frozen,
     name = "LoggingMeshClient",
@@ -89,6 +95,38 @@ impl LoggingMeshClient {
         let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
         let flush_mesh = proc_mesh.spawn("log_flusher", &()).await?;
         let logger_mesh = proc_mesh.spawn("logger", &()).await?;
+
+        // Register flush_internal as an on-stop callback
+        let client_actor_for_callback = client_actor.clone();
+        let flush_mesh_for_callback = flush_mesh.clone();
+        proc_mesh
+            .register_onstop_callback(|| async move {
+                match RealClock
+                    .timeout(
+                        FLUSH_TIMEOUT,
+                        Self::flush_internal(
+                            client_actor_for_callback,
+                            flush_mesh_for_callback,
+                        ),
+                    )
+                    .await
+                {
+                    Ok(Ok(())) => {
+                        tracing::debug!("flush completed successfully during shutdown");
+                    }
+                    Ok(Err(e)) => {
+                        tracing::error!("error during flush: {}", e);
+                    }
+                    Err(_) => {
+                        tracing::error!(
+                            "flush timed out after {} seconds during shutdown",
+                            FLUSH_TIMEOUT.as_secs()
+                        );
+                    }
+                }
+            })
+            .await?;
+
         Ok(Self {
             forwarder_mesh,
             flush_mesh,
@@ -107,7 +145,7 @@ impl LoggingMeshClient {
     ) -> PyResult<()> {
         if aggregate_window_sec.is_some() && !stream_to_client {
             return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
-                "Cannot set aggregate window without streaming to client".to_string(),
+                "cannot set aggregate window without streaming to client".to_string(),
             ));
         }
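Note the three-way match in the shutdown callback above: the timeout wrapper yields a nested result, where the outer layer reports expiry and the inner one is `flush_internal`'s own `Result`. A minimal illustration of the same shape, using `tokio::time::timeout` as a stand-in for hyperactor's `RealClock.timeout` (an assumption made to keep the example self-contained; `flush` is a placeholder):

use std::time::Duration;

async fn flush() -> Result<(), String> {
    Ok(()) // stand-in for a real flush
}

#[tokio::main]
async fn main() {
    // timeout() wraps the inner Result, so three outcomes must be handled.
    match tokio::time::timeout(Duration::from_secs(30), flush()).await {
        Ok(Ok(())) => println!("flush completed"),
        Ok(Err(e)) => eprintln!("flush failed: {}", e),
        Err(_) => eprintln!("flush timed out"),
    }
}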

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 227 additions & 3 deletions
@@ -37,6 +37,8 @@ use pyo3::types::PyType;
 use tokio::sync::Mutex;
 use tokio::sync::mpsc;

+type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;
+
 use crate::actor_mesh::PythonActorMesh;
 use crate::alloc::PyAlloc;
 use crate::mailbox::PyMailbox;
@@ -51,6 +53,7 @@ pub struct TrackedProcMesh {
     inner: SharedCellRef<ProcMesh>,
     cell: SharedCell<ProcMesh>,
     children: SharedCellPool,
+    onstop_callbacks: Arc<Mutex<Vec<OnStopCallback>>>,
 }

 impl Debug for TrackedProcMesh {
@@ -73,6 +76,7 @@ impl From<ProcMesh> for TrackedProcMesh {
             inner,
             cell,
             children: SharedCellPool::new(),
+            onstop_callbacks: Arc::new(Mutex::new(Vec::new())),
         }
     }
 }
@@ -103,8 +107,25 @@ impl TrackedProcMesh {
         self.inner.client_proc()
     }

-    pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
-        (self.cell, self.children)
+    pub fn into_inner(
+        self,
+    ) -> (
+        SharedCell<ProcMesh>,
+        SharedCellPool,
+        Arc<Mutex<Vec<OnStopCallback>>>,
+    ) {
+        (self.cell, self.children, self.onstop_callbacks)
+    }
+
+    /// Register a callback to be called when this TrackedProcMesh is stopped
+    pub async fn register_onstop_callback<F, Fut>(&self, callback: F) -> Result<(), anyhow::Error>
+    where
+        F: FnOnce() -> Fut + Send + 'static,
+        Fut: std::future::Future<Output = ()> + Send + 'static,
+    {
+        let mut callbacks = self.onstop_callbacks.lock().await;
+        callbacks.push(Box::new(|| Box::new(callback())));
+        Ok(())
     }
 }

@@ -226,7 +247,17 @@ impl PyProcMesh {
         let tracked_proc_mesh = inner.take().await.map_err(|e| {
             PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
         })?;
-        let (proc_mesh, children) = tracked_proc_mesh.into_inner();
+        let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner();
+
+        // Call all registered drop callbacks before stopping
+        let mut callbacks = drop_callbacks.lock().await;
+        let callbacks_to_call = callbacks.drain(..).collect::<Vec<_>>();
+        drop(callbacks); // Release the lock
+
+        for callback in callbacks_to_call {
+            let future = callback();
+            std::pin::Pin::from(future).await;
+        }

         // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
         children.discard_all().await?;
@@ -440,3 +471,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult
     hyperactor_mod.add_class::<PyProcEvent>()?;
     Ok(())
 }
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+    use std::sync::atomic::AtomicBool;
+    use std::sync::atomic::AtomicU32;
+    use std::sync::atomic::Ordering;
+
+    use anyhow::Result;
+    use hyperactor_mesh::alloc::AllocSpec;
+    use hyperactor_mesh::alloc::Allocator;
+    use hyperactor_mesh::alloc::local::LocalAllocator;
+    use hyperactor_mesh::proc_mesh::ProcMesh;
+    use ndslice::extent;
+    use tokio::sync::Mutex;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_single() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create a flag to track if callback was executed
+        let callback_executed = Arc::new(AtomicBool::new(false));
+        let callback_executed_clone = callback_executed.clone();
+
+        // Register a callback
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                }
+            })
+            .await?;
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger the callback)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify the callback was executed
+        assert!(
+            callback_executed.load(Ordering::SeqCst),
+            "Callback should have been executed"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_multiple() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create counters to track callback executions
+        let callback_count = Arc::new(AtomicU32::new(0));
+        let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));
+
+        // Register multiple callbacks
+        for i in 1..=3 {
+            let count = callback_count.clone();
+            let order = execution_order.clone();
+            tracked_proc_mesh
+                .register_onstop_callback(move || {
+                    let count_clone = count.clone();
+                    let order_clone = order.clone();
+                    async move {
+                        count_clone.fetch_add(1, Ordering::SeqCst);
+                        let mut order_vec = order_clone.lock().await;
+                        order_vec.push(i);
+                    }
+                })
+                .await?;
+        }
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger all callbacks)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify all callbacks were executed
+        assert_eq!(
+            callback_count.load(Ordering::SeqCst),
+            3,
+            "All 3 callbacks should have been executed"
+        );
+
+        // Verify execution order (callbacks should be executed in registration order)
+        let order_vec = execution_order.lock().await;
+        assert_eq!(
+            *order_vec,
+            vec![1, 2, 3],
+            "Callbacks should be executed in registration order"
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_register_onstop_callback_error_handling() -> Result<()> {
+        // Create a TrackedProcMesh
+        let alloc = LocalAllocator
+            .allocate(AllocSpec {
+                extent: extent! { replica = 1 },
+                constraints: Default::default(),
+            })
+            .await?;
+
+        let mut proc_mesh = ProcMesh::allocate(alloc).await?;
+
+        // Extract events before wrapping in TrackedProcMesh
+        let events = proc_mesh.events().unwrap();
+        let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
+
+        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
+
+        // Create flags to track callback executions
+        let callback1_executed = Arc::new(AtomicBool::new(false));
+        let callback2_executed = Arc::new(AtomicBool::new(false));
+
+        let callback1_executed_clone = callback1_executed.clone();
+        let callback2_executed_clone = callback2_executed.clone();
+
+        // Register a callback that completes successfully
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback1_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                    // This callback completes successfully
+                }
+            })
+            .await?;
+
+        // Register another callback that should still execute even if the first one had issues
+        tracked_proc_mesh
+            .register_onstop_callback(move || {
+                let flag = callback2_executed_clone.clone();
+                async move {
+                    flag.store(true, Ordering::SeqCst);
+                }
+            })
+            .await?;
+
+        // Create a SharedCell<TrackedProcMesh> for stop_mesh
+        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
+
+        // Call stop_mesh (this should trigger both callbacks)
+        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
+
+        // Verify both callbacks were executed
+        assert!(
+            callback1_executed.load(Ordering::SeqCst),
+            "First callback should have been executed"
+        );
+        assert!(
+            callback2_executed.load(Ordering::SeqCst),
+            "Second callback should have been executed"
+        );
+
+        Ok(())
+    }
+}
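
A subtlety worth noting in the hunk above: `OnStopCallback` returns an unpinned `Box<dyn Future>`, which cannot be awaited directly, so `stop_mesh` converts it via `std::pin::Pin::from(future)` before awaiting. A minimal sketch of just that conversion (the callback body is illustrative):

use std::future::Future;

// Same shape as the OnStopCallback alias in the diff.
type Cb = Box<dyn FnOnce() -> Box<dyn Future<Output = ()> + Send> + Send>;

#[tokio::main]
async fn main() {
    let cb: Cb = Box::new(|| Box::new(async { println!("on-stop fired") }));
    let future = cb(); // Box<dyn Future>: not yet awaitable
    // Pin<Box<dyn Future>> implements Future, so it can be awaited.
    std::pin::Pin::from(future).await;
}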

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 5 deletions
@@ -10,7 +10,6 @@
 import logging

 import click
-from monarch._src.actor.future import Future

 from monarch.actor import Actor, endpoint, proc_mesh

@@ -41,10 +40,7 @@ async def _flush_logs() -> None:
     for _ in range(5):
         await am.print.call("has print streaming")

-    # TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-    log_mesh = pm._logging_mesh_client
-    assert log_mesh is not None
-    Future(coro=log_mesh.flush().spawn().task()).get()
+    await pm.stop()


 @main.command("flush-logs")
