40 changes: 39 additions & 1 deletion monarch_extension/src/logging.rs
@@ -8,7 +8,11 @@

#![allow(unsafe_op_in_unsafe_fn)]

use std::time::Duration;

use hyperactor::ActorHandle;
use hyperactor::clock::Clock;
use hyperactor::clock::RealClock;
use hyperactor_mesh::RootActorMesh;
use hyperactor_mesh::actor_mesh::ActorMesh;
use hyperactor_mesh::logging::LogClientActor;
@@ -25,6 +29,8 @@ use pyo3::Bound;
use pyo3::prelude::*;
use pyo3::types::PyModule;

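// Upper bound on the shutdown log flush performed by the on-stop callback below.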
static FLUSH_TIMEOUT: Duration = Duration::from_secs(30);

#[pyclass(
frozen,
name = "LoggingMeshClient",
@@ -86,6 +92,38 @@ impl LoggingMeshClient {
let client_actor_ref = client_actor.bind();
let forwarder_mesh = proc_mesh.spawn("log_forwarder", &client_actor_ref).await?;
let logger_mesh = proc_mesh.spawn("logger", &()).await?;

// Register flush_internal as an on-stop callback
let client_actor_for_callback = client_actor.clone();
let forwarder_mesh_for_callback = forwarder_mesh.clone();
proc_mesh
.register_onstop_callback(|| async move {
match RealClock
.timeout(
FLUSH_TIMEOUT,
Self::flush_internal(
client_actor_for_callback,
forwarder_mesh_for_callback,
),
)
.await
{
Ok(Ok(())) => {
tracing::debug!("flush completed successfully during shutdown");
}
Ok(Err(e)) => {
tracing::error!("error during flush: {}", e);
}
Err(_) => {
tracing::error!(
"flush timed out after {} seconds during shutdown",
FLUSH_TIMEOUT.as_secs()
);
}
}
})
.await?;

Ok(Self {
forwarder_mesh,
logger_mesh,
@@ -103,7 +141,7 @@ impl LoggingMeshClient {
) -> PyResult<()> {
if aggregate_window_sec.is_some() && !stream_to_client {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
-"Cannot set aggregate window without streaming to client".to_string(),
"cannot set aggregate window without streaming to client".to_string(),
));
}

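The on-stop hook above bounds the flush with `FLUSH_TIMEOUT` so a wedged log client cannot stall mesh shutdown indefinitely. A minimal sketch of the same timeout-guarded flush using plain `tokio` in place of hyperactor's `RealClock` (the `flush` function here is a hypothetical stand-in for `Self::flush_internal`):

```rust
use std::time::Duration;

use tokio::time::timeout;

const FLUSH_TIMEOUT: Duration = Duration::from_secs(30);

// Hypothetical stand-in for `LoggingMeshClient::flush_internal`.
async fn flush() -> anyhow::Result<()> {
    Ok(())
}

async fn flush_on_shutdown() {
    // Three outcomes, as in the match above: completed, failed, or timed out.
    match timeout(FLUSH_TIMEOUT, flush()).await {
        Ok(Ok(())) => tracing::debug!("flush completed successfully during shutdown"),
        Ok(Err(e)) => tracing::error!("error during flush: {}", e),
        Err(_) => tracing::error!(
            "flush timed out after {} seconds during shutdown",
            FLUSH_TIMEOUT.as_secs()
        ),
    }
}
```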
231 changes: 227 additions & 4 deletions monarch_hyperactor/src/proc_mesh.rs
@@ -18,7 +18,6 @@ use hyperactor::WorldId;
use hyperactor::actor::RemoteActor;
use hyperactor::proc::Proc;
use hyperactor_mesh::RootActorMesh;
-use hyperactor_mesh::alloc::Alloc;
use hyperactor_mesh::alloc::ProcStopReason;
use hyperactor_mesh::proc_mesh::ProcEvent;
use hyperactor_mesh::proc_mesh::ProcEvents;
@@ -38,6 +37,8 @@ use pyo3::types::PyType;
use tokio::sync::Mutex;
use tokio::sync::mpsc;

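// Type-erased on-stop hook: a boxed one-shot closure that returns a boxed future.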
type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;

use crate::actor_mesh::PythonActorMesh;
use crate::actor_mesh::PythonActorMeshImpl;
use crate::alloc::PyAlloc;
@@ -55,6 +56,7 @@ pub struct TrackedProcMesh {
inner: SharedCellRef<ProcMesh>,
cell: SharedCell<ProcMesh>,
children: SharedCellPool,
onstop_callbacks: Arc<Mutex<Vec<OnStopCallback>>>,
}

impl Debug for TrackedProcMesh {
@@ -77,6 +79,7 @@ impl From<ProcMesh> for TrackedProcMesh {
inner,
cell,
children: SharedCellPool::new(),
onstop_callbacks: Arc::new(Mutex::new(Vec::new())),
}
}
}
@@ -107,8 +110,25 @@ impl TrackedProcMesh {
self.inner.client_proc()
}

-pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
-(self.cell, self.children)
pub fn into_inner(
self,
) -> (
SharedCell<ProcMesh>,
SharedCellPool,
Arc<Mutex<Vec<OnStopCallback>>>,
) {
(self.cell, self.children, self.onstop_callbacks)
}

/// Register a callback to be called when this TrackedProcMesh is stopped
pub async fn register_onstop_callback<F, Fut>(&self, callback: F) -> Result<(), anyhow::Error>
where
F: FnOnce() -> Fut + Send + 'static,
Fut: std::future::Future<Output = ()> + Send + 'static,
{
let mut callbacks = self.onstop_callbacks.lock().await;
callbacks.push(Box::new(|| Box::new(callback())));
Ok(())
}
}

@@ -230,7 +250,17 @@ impl PyProcMesh {
let tracked_proc_mesh = inner.take().await.map_err(|e| {
PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
})?;
-let (proc_mesh, children) = tracked_proc_mesh.into_inner();
let (proc_mesh, children, drop_callbacks) = tracked_proc_mesh.into_inner();

// Call all registered drop callbacks before stopping
let mut callbacks = drop_callbacks.lock().await;
let callbacks_to_call = callbacks.drain(..).collect::<Vec<_>>();
drop(callbacks); // Release the lock

for callback in callbacks_to_call {
let future = callback();
std::pin::Pin::from(future).await;
}

// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
// Discarding actor meshes that have been individually stopped will result in an expected error
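The drain loop above works because `OnStopCallback` erases both the registrant's closure type and its future type; since a `Box<dyn Future>` cannot be polled directly, the boxed future is pinned via `std::pin::Pin::from` before being awaited. The pattern in isolation (illustrative names, not part of this PR):

```rust
use std::future::Future;
use std::pin::Pin;

// Same shape as `OnStopCallback`: a boxed one-shot closure returning a boxed future.
type Callback = Box<dyn FnOnce() -> Box<dyn Future<Output = ()> + Send> + Send>;

fn push_callback<F, Fut>(callbacks: &mut Vec<Callback>, callback: F)
where
    F: FnOnce() -> Fut + Send + 'static,
    Fut: Future<Output = ()> + Send + 'static,
{
    // The outer Box erases the closure type F; the inner Box erases the future type Fut.
    callbacks.push(Box::new(|| Box::new(callback())));
}

async fn drain(callbacks: Vec<Callback>) {
    for callback in callbacks {
        // A boxed dyn future must be pinned before it can be awaited.
        Pin::from(callback()).await;
    }
}
```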
@@ -488,3 +518,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult
hyperactor_mod.add_class::<PyProcEvent>()?;
Ok(())
}

#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::AtomicU32;
use std::sync::atomic::Ordering;

use anyhow::Result;
use hyperactor_mesh::alloc::AllocSpec;
use hyperactor_mesh::alloc::Allocator;
use hyperactor_mesh::alloc::local::LocalAllocator;
use hyperactor_mesh::proc_mesh::ProcMesh;
use ndslice::extent;
use tokio::sync::Mutex;

use super::*;

#[tokio::test]
async fn test_register_onstop_callback_single() -> Result<()> {
// Create a TrackedProcMesh
let alloc = LocalAllocator
.allocate(AllocSpec {
extent: extent! { replica = 1 },
constraints: Default::default(),
})
.await?;

let mut proc_mesh = ProcMesh::allocate(alloc).await?;

// Extract events before wrapping in TrackedProcMesh
let events = proc_mesh.events().unwrap();
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));

let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

// Create a flag to track if callback was executed
let callback_executed = Arc::new(AtomicBool::new(false));
let callback_executed_clone = callback_executed.clone();

// Register a callback
tracked_proc_mesh
.register_onstop_callback(move || {
let flag = callback_executed_clone.clone();
async move {
flag.store(true, Ordering::SeqCst);
}
})
.await?;

// Create a SharedCell<TrackedProcMesh> for stop_mesh
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);

// Call stop_mesh (this should trigger the callback)
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

// Verify the callback was executed
assert!(
callback_executed.load(Ordering::SeqCst),
"Callback should have been executed"
);

Ok(())
}

#[tokio::test]
async fn test_register_onstop_callback_multiple() -> Result<()> {
// Create a TrackedProcMesh
let alloc = LocalAllocator
.allocate(AllocSpec {
extent: extent! { replica = 1 },
constraints: Default::default(),
})
.await?;

let mut proc_mesh = ProcMesh::allocate(alloc).await?;

// Extract events before wrapping in TrackedProcMesh
let events = proc_mesh.events().unwrap();
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));

let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

// Create counters to track callback executions
let callback_count = Arc::new(AtomicU32::new(0));
let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));

// Register multiple callbacks
for i in 1..=3 {
let count = callback_count.clone();
let order = execution_order.clone();
tracked_proc_mesh
.register_onstop_callback(move || {
let count_clone = count.clone();
let order_clone = order.clone();
async move {
count_clone.fetch_add(1, Ordering::SeqCst);
let mut order_vec = order_clone.lock().await;
order_vec.push(i);
}
})
.await?;
}

// Create a SharedCell<TrackedProcMesh> for stop_mesh
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);

// Call stop_mesh (this should trigger all callbacks)
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

// Verify all callbacks were executed
assert_eq!(
callback_count.load(Ordering::SeqCst),
3,
"All 3 callbacks should have been executed"
);

// Verify execution order (callbacks should be executed in registration order)
let order_vec = execution_order.lock().await;
assert_eq!(
*order_vec,
vec![1, 2, 3],
"Callbacks should be executed in registration order"
);

Ok(())
}

#[tokio::test]
async fn test_register_onstop_callback_error_handling() -> Result<()> {
// Create a TrackedProcMesh
let alloc = LocalAllocator
.allocate(AllocSpec {
extent: extent! { replica = 1 },
constraints: Default::default(),
})
.await?;

let mut proc_mesh = ProcMesh::allocate(alloc).await?;

// Extract events before wrapping in TrackedProcMesh
let events = proc_mesh.events().unwrap();
let proc_events_cell = SharedCell::from(tokio::sync::Mutex::new(events));

let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

// Create flags to track callback executions
let callback1_executed = Arc::new(AtomicBool::new(false));
let callback2_executed = Arc::new(AtomicBool::new(false));

let callback1_executed_clone = callback1_executed.clone();
let callback2_executed_clone = callback2_executed.clone();

// Register a first callback that completes successfully
tracked_proc_mesh
.register_onstop_callback(move || {
let flag = callback1_executed_clone.clone();
async move {
flag.store(true, Ordering::SeqCst);
// This callback completes successfully
}
})
.await?;

// Register a second callback that should also execute
tracked_proc_mesh
.register_onstop_callback(move || {
let flag = callback2_executed_clone.clone();
async move {
flag.store(true, Ordering::SeqCst);
}
})
.await?;

// Create a SharedCell<TrackedProcMesh> for stop_mesh
let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);

// Call stop_mesh (this should trigger both callbacks)
PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

// Verify both callbacks were executed
assert!(
callback1_executed.load(Ordering::SeqCst),
"First callback should have been executed"
);
assert!(
callback2_executed.load(Ordering::SeqCst),
"Second callback should have been executed"
);

Ok(())
}
}
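One caveat the tests do not exercise: `stop_mesh` awaits each callback inline, so a callback that actually panics would propagate and skip the remaining callbacks. If isolation were desired, each future could be wrapped with `catch_unwind` from the `futures` crate; a sketch under that assumption (not part of this PR):

```rust
use std::future::Future;
use std::panic::AssertUnwindSafe;
use std::pin::Pin;

use futures::FutureExt;

type Callback = Box<dyn FnOnce() -> Box<dyn Future<Output = ()> + Send> + Send>;

async fn drain_isolated(callbacks: Vec<Callback>) {
    for callback in callbacks {
        // AssertUnwindSafe is needed because a boxed dyn future is not UnwindSafe.
        if AssertUnwindSafe(Pin::from(callback()))
            .catch_unwind()
            .await
            .is_err()
        {
            tracing::error!("on-stop callback panicked; continuing with the rest");
        }
    }
}
```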
6 changes: 1 addition & 5 deletions python/tests/python_actor_test_binary.py
@@ -10,7 +10,6 @@
import logging

import click
-from monarch._src.actor.future import Future

from monarch.actor import Actor, endpoint, proc_mesh

@@ -41,10 +40,7 @@ async def _flush_logs() -> None:
for _ in range(5):
await am.print.call("has print streaming")

-# TODO: remove this completely once we hook the flush logic upon dropping device_mesh
-log_mesh = pm._logging_manager._logging_mesh_client
-assert log_mesh is not None
-Future(coro=log_mesh.flush().spawn().task()).get()
await pm.stop()


@main.command("flush-logs")