@@ -18,7 +18,6 @@ use hyperactor::WorldId;
18
18
use hyperactor:: actor:: RemoteActor ;
19
19
use hyperactor:: proc:: Proc ;
20
20
use hyperactor_mesh:: RootActorMesh ;
21
- use hyperactor_mesh:: alloc:: Alloc ;
22
21
use hyperactor_mesh:: alloc:: ProcStopReason ;
23
22
use hyperactor_mesh:: proc_mesh:: ProcEvent ;
24
23
use hyperactor_mesh:: proc_mesh:: ProcEvents ;
@@ -38,6 +37,8 @@ use pyo3::types::PyType;
38
37
use tokio:: sync:: Mutex ;
39
38
use tokio:: sync:: mpsc;
40
39
40
/// Type-erased callback that is run when a `TrackedProcMesh` is stopped.
///
/// Each callback is a one-shot (`FnOnce`) closure taking no arguments and
/// returning a boxed future that resolves to `()`. Both the closure and the
/// future it produces must be `Send`.
///
/// Note that the future is a plain `Box`, not a `Pin<Box<..>>`, so it has to
/// be pinned (e.g. via `Pin::from`) before it can be awaited.
+ type OnStopCallback = Box < dyn FnOnce ( ) -> Box < dyn std:: future:: Future < Output = ( ) > + Send > + Send > ;
41
+
41
42
use crate :: actor_mesh:: PythonActorMesh ;
42
43
use crate :: actor_mesh:: PythonActorMeshImpl ;
43
44
use crate :: alloc:: PyAlloc ;
@@ -55,6 +56,7 @@ pub struct TrackedProcMesh {
55
56
inner : SharedCellRef < ProcMesh > ,
56
57
cell : SharedCell < ProcMesh > ,
57
58
children : SharedCellPool ,
59
+ onstop_callbacks : Arc < Mutex < Vec < OnStopCallback > > > ,
58
60
}
59
61
60
62
impl Debug for TrackedProcMesh {
@@ -77,6 +79,7 @@ impl From<ProcMesh> for TrackedProcMesh {
77
79
inner,
78
80
cell,
79
81
children : SharedCellPool :: new ( ) ,
82
+ onstop_callbacks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
80
83
}
81
84
}
82
85
}
@@ -107,8 +110,25 @@ impl TrackedProcMesh {
107
110
self . inner . client_proc ( )
108
111
}
109
112
110
- pub fn into_inner ( self ) -> ( SharedCell < ProcMesh > , SharedCellPool ) {
111
- ( self . cell , self . children )
113
+ pub fn into_inner (
114
+ self ,
115
+ ) -> (
116
+ SharedCell < ProcMesh > ,
117
+ SharedCellPool ,
118
+ Arc < Mutex < Vec < OnStopCallback > > > ,
119
+ ) {
120
+ ( self . cell , self . children , self . onstop_callbacks )
121
+ }
122
+
123
+ /// Register a callback to be called when this TrackedProcMesh is stopped
124
+ pub async fn register_onstop_callback < F , Fut > ( & self , callback : F ) -> Result < ( ) , anyhow:: Error >
125
+ where
126
+ F : FnOnce ( ) -> Fut + Send + ' static ,
127
+ Fut : std:: future:: Future < Output = ( ) > + Send + ' static ,
128
+ {
129
+ let mut callbacks = self . onstop_callbacks . lock ( ) . await ;
130
+ callbacks. push ( Box :: new ( || Box :: new ( callback ( ) ) ) ) ;
131
+ Ok ( ( ) )
112
132
}
113
133
}
114
134
@@ -230,7 +250,17 @@ impl PyProcMesh {
230
250
let tracked_proc_mesh = inner. take ( ) . await . map_err ( |e| {
231
251
PyRuntimeError :: new_err ( format ! ( "`ProcMesh` has already been stopped: {}" , e) )
232
252
} ) ?;
233
- let ( proc_mesh, children) = tracked_proc_mesh. into_inner ( ) ;
253
+ let ( proc_mesh, children, drop_callbacks) = tracked_proc_mesh. into_inner ( ) ;
254
+
255
+ // Run every registered on-stop callback, one at a time in registration order, before stopping
256
+ let mut callbacks = drop_callbacks. lock ( ) . await ;
257
+ let callbacks_to_call = callbacks. drain ( ..) . collect :: < Vec < _ > > ( ) ;
258
+ drop ( callbacks) ; // Release the lock
259
+
260
+ for callback in callbacks_to_call {
261
+ let future = callback ( ) ;
262
+ std:: pin:: Pin :: from ( future) . await ;
263
+ }
234
264
235
265
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
236
266
// Discarding actor meshes that have been individually stopped will result in an expected error
@@ -488,3 +518,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
488
518
hyperactor_mod. add_class :: < PyProcEvent > ( ) ?;
489
519
Ok ( ( ) )
490
520
}
521
+
522
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicU32;
    use std::sync::atomic::Ordering;

    use anyhow::Result;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::local::LocalAllocator;
    use hyperactor_mesh::proc_mesh::ProcMesh;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;

    /// Allocate a single-replica local `ProcMesh` and wrap it in a
    /// `TrackedProcMesh`, returning the mesh's event stream alongside it.
    async fn new_tracked_proc_mesh() -> Result<(TrackedProcMesh, ProcEvents)> {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent! { replica = 1 },
                constraints: Default::default(),
            })
            .await?;

        let mut proc_mesh = ProcMesh::allocate(alloc).await?;

        // The events have to be extracted before the mesh is moved into the
        // `TrackedProcMesh` wrapper.
        let events = proc_mesh.events().unwrap();

        Ok((TrackedProcMesh::from(proc_mesh), events))
    }

    /// Wrap the mesh and its events in `SharedCell`s and run
    /// `PyProcMesh::stop_mesh` on them, which should trigger every
    /// registered on-stop callback.
    async fn stop_tracked_proc_mesh(
        tracked_proc_mesh: TrackedProcMesh,
        events: ProcEvents,
    ) -> Result<()> {
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        let proc_events_cell = SharedCell::from(Mutex::new(events));
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;
        Ok(())
    }

    /// A single registered callback runs when the mesh is stopped.
    #[tokio::test]
    async fn test_register_onstop_callback_single() -> Result<()> {
        let (tracked_proc_mesh, events) = new_tracked_proc_mesh().await?;

        // Flag flipped by the callback so the test can observe that it ran.
        let callback_executed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&callback_executed);

        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                flag.store(true, Ordering::SeqCst);
            })
            .await?;

        stop_tracked_proc_mesh(tracked_proc_mesh, events).await?;

        assert!(
            callback_executed.load(Ordering::SeqCst),
            "Callback should have been executed"
        );

        Ok(())
    }

    /// Several registered callbacks all run, in the order they were registered.
    #[tokio::test]
    async fn test_register_onstop_callback_multiple() -> Result<()> {
        let (tracked_proc_mesh, events) = new_tracked_proc_mesh().await?;

        // Shared state recording how many callbacks ran and in which order.
        let callback_count = Arc::new(AtomicU32::new(0));
        let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));

        for i in 1..=3 {
            let count = Arc::clone(&callback_count);
            let order = Arc::clone(&execution_order);
            tracked_proc_mesh
                .register_onstop_callback(move || async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    order.lock().await.push(i);
                })
                .await?;
        }

        stop_tracked_proc_mesh(tracked_proc_mesh, events).await?;

        assert_eq!(
            callback_count.load(Ordering::SeqCst),
            3,
            "All 3 callbacks should have been executed"
        );

        let order_vec = execution_order.lock().await;
        assert_eq!(
            *order_vec,
            vec![1, 2, 3],
            "Callbacks should be executed in registration order"
        );

        Ok(())
    }

    /// Two independently registered callbacks both run when the mesh is stopped.
    ///
    /// Despite the test's name, neither callback fails or panics here: both
    /// complete successfully. This test therefore only checks that running
    /// one callback does not prevent the next one from running; it does not
    /// exercise any failure path.
    #[tokio::test]
    async fn test_register_onstop_callback_error_handling() -> Result<()> {
        let (tracked_proc_mesh, events) = new_tracked_proc_mesh().await?;

        // One flag per callback so each execution can be observed separately.
        let callback1_executed = Arc::new(AtomicBool::new(false));
        let callback2_executed = Arc::new(AtomicBool::new(false));

        // Register the first callback; it completes successfully.
        let flag1 = Arc::clone(&callback1_executed);
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                flag1.store(true, Ordering::SeqCst);
            })
            .await?;

        // Register a second callback that must run after the first one.
        let flag2 = Arc::clone(&callback2_executed);
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                flag2.store(true, Ordering::SeqCst);
            })
            .await?;

        stop_tracked_proc_mesh(tracked_proc_mesh, events).await?;

        assert!(
            callback1_executed.load(Ordering::SeqCst),
            "First callback should have been executed"
        );
        assert!(
            callback2_executed.load(Ordering::SeqCst),
            "Second callback should have been executed"
        );

        Ok(())
    }
}
0 commit comments