Skip to content

Commit 116a003

Browse files
highker authored and facebook-github-bot committed
sync flush logs upon mesh stop (#885)
Summary: force sync flush upon mesh stop Reviewed By: vidhyav Differential Revision: D80310284
1 parent d43b3ed commit 116a003

File tree

4 files changed

+271
-26
lines changed

4 files changed

+271
-26
lines changed

monarch_extension/src/logging.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
#![allow(unsafe_op_in_unsafe_fn)]
1010

11+
use std::time::Duration;
12+
1113
use hyperactor::ActorHandle;
14+
use hyperactor::clock::Clock;
15+
use hyperactor::clock::RealClock;
1216
use hyperactor_mesh::RootActorMesh;
1317
use hyperactor_mesh::actor_mesh::ActorMesh;
1418
use hyperactor_mesh::logging::LogClientActor;
@@ -25,6 +29,8 @@ use pyo3::Bound;
2529
use pyo3::prelude::*;
2630
use pyo3::types::PyModule;
2731

32+
static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);
33+
2834
#[pyclass(
2935
frozen,
3036
name = "LoggingMeshClient",
@@ -86,6 +92,38 @@ impl LoggingMeshClient {
8692
let client_actor_ref = client_actor.bind();
8793
let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
8894
let logger_mesh = proc_mesh.spawn("logger", &()).await?;
95+
96+
// Register flush_internal as a on-stop callback
97+
let client_actor_for_callback = client_actor.clone();
98+
let forwarder_mesh_for_callback = forwarder_mesh.clone();
99+
proc_mesh
100+
.register_onstop_callback(|| async move {
101+
match RealClock
102+
.timeout(
103+
FLUSH_TIMEOUT,
104+
Self::flush_internal(
105+
client_actor_for_callback,
106+
forwarder_mesh_for_callback,
107+
),
108+
)
109+
.await
110+
{
111+
Ok(Ok(())) => {
112+
tracing::debug!("flush completed successfully during shutdown");
113+
}
114+
Ok(Err(e)) => {
115+
tracing::error!("error during flush: {}", e);
116+
}
117+
Err(_) => {
118+
tracing::error!(
119+
"flush timed out after {} seconds during shutdown",
120+
FLUSH_TIMEOUT.as_secs()
121+
);
122+
}
123+
}
124+
})
125+
.await?;
126+
89127
Ok(Self {
90128
forwarder_mesh,
91129
logger_mesh,
@@ -103,7 +141,7 @@ impl LoggingMeshClient {
103141
) -> PyResult<()> {
104142
if aggregate_window_sec.is_some() && !stream_to_client {
105143
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
106-
"Cannot set aggregate window without streaming to client".to_string(),
144+
"cannot set aggregate window without streaming to client".to_string(),
107145
));
108146
}
109147

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 227 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ use hyperactor::WorldId;
1818
use hyperactor::actor::RemoteActor;
1919
use hyperactor::proc::Proc;
2020
use hyperactor_mesh::RootActorMesh;
21-
use hyperactor_mesh::alloc::Alloc;
2221
use hyperactor_mesh::alloc::ProcStopReason;
2322
use hyperactor_mesh::proc_mesh::ProcEvent;
2423
use hyperactor_mesh::proc_mesh::ProcEvents;
@@ -38,6 +37,8 @@ use pyo3::types::PyType;
3837
use tokio::sync::Mutex;
3938
use tokio::sync::mpsc;
4039

40+
type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;
41+
4142
use crate::actor_mesh::PythonActorMesh;
4243
use crate::actor_mesh::PythonActorMeshImpl;
4344
use crate::alloc::PyAlloc;
@@ -55,6 +56,7 @@ pub struct TrackedProcMesh {
5556
inner: SharedCellRef<ProcMesh>,
5657
cell: SharedCell<ProcMesh>,
5758
children: SharedCellPool,
59+
onstop_callbacks: Arc<Mutex<Vec<OnStopCallback>>>,
5860
}
5961

6062
impl Debug for TrackedProcMesh {
@@ -77,6 +79,7 @@ impl From<ProcMesh> for TrackedProcMesh {
7779
inner,
7880
cell,
7981
children: SharedCellPool::new(),
82+
onstop_callbacks: Arc::new(Mutex::new(Vec::new())),
8083
}
8184
}
8285
}
@@ -107,8 +110,25 @@ impl TrackedProcMesh {
107110
self.inner.client_proc()
108111
}
109112

110-
pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
111-
(self.cell, self.children)
113+
pub fn into_inner(
114+
self,
115+
) -> (
116+
SharedCell<ProcMesh>,
117+
SharedCellPool,
118+
Arc<Mutex<Vec<OnStopCallback>>>,
119+
) {
120+
(self.cell, self.children, self.onstop_callbacks)
121+
}
122+
123+
/// Register a callback to be called when this TrackedProcMesh is stopped
124+
pub async fn register_onstop_callback<F, Fut>(&self, callback: F) -> Result<(), anyhow::Error>
125+
where
126+
F: FnOnce() -> Fut + Send + 'static,
127+
Fut: std::future::Future<Output = ()> + Send + 'static,
128+
{
129+
let mut callbacks = self.onstop_callbacks.lock().await;
130+
callbacks.push(Box::new(|| Box::new(callback())));
131+
Ok(())
112132
}
113133
}
114134

@@ -230,7 +250,17 @@ impl PyProcMesh {
230250
let tracked_proc_mesh = inner.take().await.map_err(|e| {
231251
PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
232252
})?;
233-
let (proc_mesh, children) = tracked_proc_mesh.into_inner();
253+
let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner();
254+
255+
// Call all registered drop callbacks before stopping
256+
let mut callbacks = drop_callbacks.lock().await;
257+
let callbacks_to_call = callbacks.drain(..).collect::<Vec<_>>();
258+
drop(callbacks); // Release the lock
259+
260+
for callback in callbacks_to_call {
261+
let future = callback();
262+
std::pin::Pin::from(future).await;
263+
}
234264

235265
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
236266
// Discarding actor meshes that have been individually stopped will result in an expected error
@@ -488,3 +518,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
488518
hyperactor_mod.add_class::<PyProcEvent>()?;
489519
Ok(())
490520
}
521+
522+
#[cfg(test)]
523+
mod tests {
524+
use std::sync::Arc;
525+
use std::sync::atomic::AtomicBool;
526+
use std::sync::atomic::AtomicU32;
527+
use std::sync::atomic::Ordering;
528+
529+
use anyhow::Result;
530+
use hyperactor_mesh::alloc::AllocSpec;
531+
use hyperactor_mesh::alloc::Allocator;
532+
use hyperactor_mesh::alloc::local::LocalAllocator;
533+
use hyperactor_mesh::proc_mesh::ProcMesh;
534+
use ndslice::extent;
535+
use tokio::sync::Mutex;
536+
537+
use super::*;
538+
539+
#[tokio::test]
540+
async fn test_register_onstop_callback_single() -> Result<()> {
541+
// Create a TrackedProcMesh
542+
let alloc = LocalAllocator
543+
.allocate(AllocSpec {
544+
extent: extent! { replica = 1 },
545+
constraints: Default::default(),
546+
})
547+
.await?;
548+
549+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
550+
551+
// Extract events before wrapping in TrackedProcMesh
552+
let events = proc_mesh.events().unwrap();
553+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
554+
555+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
556+
557+
// Create a flag to track if callback was executed
558+
let callback_executed = Arc::new(AtomicBool::new(false));
559+
let callback_executed_clone = callback_executed.clone();
560+
561+
// Register a callback
562+
tracked_proc_mesh
563+
.register_onstop_callback(move || {
564+
let flag = callback_executed_clone.clone();
565+
async move {
566+
flag.store(true, Ordering::SeqCst);
567+
}
568+
})
569+
.await?;
570+
571+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
572+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
573+
574+
// Call stop_mesh (this should trigger the callback)
575+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
576+
577+
// Verify the callback was executed
578+
assert!(
579+
callback_executed.load(Ordering::SeqCst),
580+
"Callback should have been executed"
581+
);
582+
583+
Ok(())
584+
}
585+
586+
#[tokio::test]
587+
async fn test_register_onstop_callback_multiple() -> Result<()> {
588+
// Create a TrackedProcMesh
589+
let alloc = LocalAllocator
590+
.allocate(AllocSpec {
591+
extent: extent! { replica = 1 },
592+
constraints: Default::default(),
593+
})
594+
.await?;
595+
596+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
597+
598+
// Extract events before wrapping in TrackedProcMesh
599+
let events = proc_mesh.events().unwrap();
600+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
601+
602+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
603+
604+
// Create counters to track callback executions
605+
let callback_count = Arc::new(AtomicU32::new(0));
606+
let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));
607+
608+
// Register multiple callbacks
609+
for i in 1..=3 {
610+
let count = callback_count.clone();
611+
let order = execution_order.clone();
612+
tracked_proc_mesh
613+
.register_onstop_callback(move || {
614+
let count_clone = count.clone();
615+
let order_clone = order.clone();
616+
async move {
617+
count_clone.fetch_add(1, Ordering::SeqCst);
618+
let mut order_vec = order_clone.lock().await;
619+
order_vec.push(i);
620+
}
621+
})
622+
.await?;
623+
}
624+
625+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
626+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
627+
628+
// Call stop_mesh (this should trigger all callbacks)
629+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
630+
631+
// Verify all callbacks were executed
632+
assert_eq!(
633+
callback_count.load(Ordering::SeqCst),
634+
3,
635+
"All 3 callbacks should have been executed"
636+
);
637+
638+
// Verify execution order (callbacks should be executed in registration order)
639+
let order_vec = execution_order.lock().await;
640+
assert_eq!(
641+
*order_vec,
642+
vec![1, 2, 3],
643+
"Callbacks should be executed in registration order"
644+
);
645+
646+
Ok(())
647+
}
648+
649+
#[tokio::test]
650+
async fn test_register_onstop_callback_error_handling() -> Result<()> {
651+
// Create a TrackedProcMesh
652+
let alloc = LocalAllocator
653+
.allocate(AllocSpec {
654+
extent: extent! { replica = 1 },
655+
constraints: Default::default(),
656+
})
657+
.await?;
658+
659+
let mut proc_mesh = ProcMesh::allocate(alloc).await?;
660+
661+
// Extract events before wrapping in TrackedProcMesh
662+
let events = proc_mesh.events().unwrap();
663+
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));
664+
665+
let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);
666+
667+
// Create flags to track callback executions
668+
let callback1_executed = Arc::new(AtomicBool::new(false));
669+
let callback2_executed = Arc::new(AtomicBool::new(false));
670+
671+
let callback1_executed_clone = callback1_executed.clone();
672+
let callback2_executed_clone = callback2_executed.clone();
673+
674+
// Register a callback that panics
675+
tracked_proc_mesh
676+
.register_onstop_callback(move || {
677+
let flag = callback1_executed_clone.clone();
678+
async move {
679+
flag.store(true, Ordering::SeqCst);
680+
// This callback completes successfully
681+
}
682+
})
683+
.await?;
684+
685+
// Register another callback that should still execute even if the first one had issues
686+
tracked_proc_mesh
687+
.register_onstop_callback(move || {
688+
let flag = callback2_executed_clone.clone();
689+
async move {
690+
flag.store(true, Ordering::SeqCst);
691+
}
692+
})
693+
.await?;
694+
695+
// Create a SharedCell<TrackedProcMesh> for stop_mesh
696+
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
697+
698+
// Call stop_mesh (this should trigger both callbacks)
699+
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
700+
701+
// Verify both callbacks were executed
702+
assert!(
703+
callback1_executed.load(Ordering::SeqCst),
704+
"First callback should have been executed"
705+
);
706+
assert!(
707+
callback2_executed.load(Ordering::SeqCst),
708+
"Second callback should have been executed"
709+
);
710+
711+
Ok(())
712+
}
713+
}

python/tests/python_actor_test_binary.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import logging
1111

1212
import click
13-
from monarch._src.actor.future import Future
1413

1514
from monarch.actor import Actor, endpoint, proc_mesh
1615

@@ -41,10 +40,7 @@ async def _flush_logs() -> None:
4140
for _ in range(5):
4241
await am.print.call("has print streaming")
4342

44-
# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
45-
log_mesh = pm._logging_manager._logging_mesh_client
46-
assert log_mesh is not None
47-
Future(coro=log_mesh.flush().spawn().task()).get()
43+
await pm.stop()
4844

4945

5046
@main.command("flush-logs")

0 commit comments

Comments (0)