@@ -37,6 +37,8 @@ use pyo3::types::PyType;
37
37
use tokio:: sync:: Mutex ;
38
38
use tokio:: sync:: mpsc;
39
39
40
/// A one-shot callback run when a tracked proc mesh is stopped.
///
/// Invoking the boxed closure consumes it and yields a boxed, `Send` future
/// that resolves to `()`. The future is boxed but not pinned, so a consumer has
/// to pin it (for example via `Pin::from`) before awaiting it.
type OnStopCallback = Box<dyn FnOnce() -> Box<dyn std::future::Future<Output = ()> + Send> + Send>;
41
+
40
42
use crate :: actor_mesh:: PythonActorMesh ;
41
43
use crate :: alloc:: PyAlloc ;
42
44
use crate :: mailbox:: PyMailbox ;
@@ -51,6 +53,7 @@ pub struct TrackedProcMesh {
51
53
inner : SharedCellRef < ProcMesh > ,
52
54
cell : SharedCell < ProcMesh > ,
53
55
children : SharedCellPool ,
56
+ onstop_callbacks : Arc < Mutex < Vec < OnStopCallback > > > ,
54
57
}
55
58
56
59
impl Debug for TrackedProcMesh {
@@ -73,6 +76,7 @@ impl From<ProcMesh> for TrackedProcMesh {
73
76
inner,
74
77
cell,
75
78
children : SharedCellPool :: new ( ) ,
79
+ onstop_callbacks : Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ,
76
80
}
77
81
}
78
82
}
@@ -103,8 +107,25 @@ impl TrackedProcMesh {
103
107
self . inner . client_proc ( )
104
108
}
105
109
106
- pub fn into_inner ( self ) -> ( SharedCell < ProcMesh > , SharedCellPool ) {
107
- ( self . cell , self . children )
110
+ pub fn into_inner (
111
+ self ,
112
+ ) -> (
113
+ SharedCell < ProcMesh > ,
114
+ SharedCellPool ,
115
+ Arc < Mutex < Vec < OnStopCallback > > > ,
116
+ ) {
117
+ ( self . cell , self . children , self . onstop_callbacks )
118
+ }
119
+
120
+ /// Register a callback to be called when this TrackedProcMesh is stopped
121
+ pub async fn register_onstop_callback < F , Fut > ( & self , callback : F ) -> Result < ( ) , anyhow:: Error >
122
+ where
123
+ F : FnOnce ( ) -> Fut + Send + ' static ,
124
+ Fut : std:: future:: Future < Output = ( ) > + Send + ' static ,
125
+ {
126
+ let mut callbacks = self . onstop_callbacks . lock ( ) . await ;
127
+ callbacks. push ( Box :: new ( || Box :: new ( callback ( ) ) ) ) ;
128
+ Ok ( ( ) )
108
129
}
109
130
}
110
131
@@ -226,7 +247,17 @@ impl PyProcMesh {
226
247
let tracked_proc_mesh = inner. take ( ) . await . map_err ( |e| {
227
248
PyRuntimeError :: new_err ( format ! ( "`ProcMesh` has already been stopped: {}" , e) )
228
249
} ) ?;
229
- let ( proc_mesh, children) = tracked_proc_mesh. into_inner ( ) ;
250
+ let ( proc_mesh, children, drop_callbacks) = tracked_proc_mesh. into_inner ( ) ;
251
+
252
+ // Call all registered drop callbacks before stopping
253
+ let mut callbacks = drop_callbacks. lock ( ) . await ;
254
+ let callbacks_to_call = callbacks. drain ( ..) . collect :: < Vec < _ > > ( ) ;
255
+ drop ( callbacks) ; // Release the lock
256
+
257
+ for callback in callbacks_to_call {
258
+ let future = callback ( ) ;
259
+ std:: pin:: Pin :: from ( future) . await ;
260
+ }
230
261
231
262
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
232
263
children. discard_all ( ) . await ?;
@@ -440,3 +471,196 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
440
471
hyperactor_mod. add_class :: < PyProcEvent > ( ) ?;
441
472
Ok ( ( ) )
442
473
}
474
+
475
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::AtomicU32;
    use std::sync::atomic::Ordering;

    use anyhow::Result;
    use hyperactor_mesh::alloc::AllocSpec;
    use hyperactor_mesh::alloc::Allocator;
    use hyperactor_mesh::alloc::local::LocalAllocator;
    use hyperactor_mesh::proc_mesh::ProcMesh;
    use ndslice::extent;
    use tokio::sync::Mutex;

    use super::*;

    /// Allocates a single-replica `ProcMesh` through the `LocalAllocator`.
    ///
    /// Shared by every test below so the allocation boilerplate lives in one
    /// place.
    async fn allocate_local_proc_mesh() -> Result<ProcMesh> {
        let alloc = LocalAllocator
            .allocate(AllocSpec {
                extent: extent! { replica = 1 },
                constraints: Default::default(),
            })
            .await?;
        Ok(ProcMesh::allocate(alloc).await?)
    }

    /// A single registered callback runs when the mesh is stopped.
    #[tokio::test]
    async fn test_register_onstop_callback_single() -> Result<()> {
        let mut proc_mesh = allocate_local_proc_mesh().await?;

        // Extract events before wrapping in TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("a freshly allocated ProcMesh should expose its events");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // Flag flipped by the callback so the test can observe that it ran.
        let callback_executed = Arc::new(AtomicBool::new(false));
        let flag = callback_executed.clone();

        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // stop_mesh takes the tracked mesh out of a SharedCell and is expected
        // to run the registered callback.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert!(
            callback_executed.load(Ordering::SeqCst),
            "Callback should have been executed"
        );

        Ok(())
    }

    /// Several registered callbacks all run, in registration order.
    #[tokio::test]
    async fn test_register_onstop_callback_multiple() -> Result<()> {
        let mut proc_mesh = allocate_local_proc_mesh().await?;

        // Extract events before wrapping in TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("a freshly allocated ProcMesh should expose its events");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // `callback_count` records how many callbacks ran; `execution_order`
        // records the order in which they ran.
        let callback_count = Arc::new(AtomicU32::new(0));
        let execution_order = Arc::new(Mutex::new(Vec::<u32>::new()));

        for i in 1..=3u32 {
            let count = callback_count.clone();
            let order = execution_order.clone();
            tracked_proc_mesh
                .register_onstop_callback(move || async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    order.lock().await.push(i);
                })
                .await?;
        }

        // stop_mesh is expected to run every registered callback.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert_eq!(
            callback_count.load(Ordering::SeqCst),
            3,
            "All 3 callbacks should have been executed"
        );

        let order_vec = execution_order.lock().await;
        assert_eq!(
            *order_vec,
            vec![1, 2, 3],
            "Callbacks should be executed in registration order"
        );

        Ok(())
    }

    /// Two independent callbacks both run to completion.
    ///
    /// NOTE(review): despite the test name, neither callback fails or panics,
    /// so this does not exercise any error path. It only checks that a second
    /// callback still runs after the first one has completed. Behaviour for a
    /// callback that actually panics is not covered here.
    #[tokio::test]
    async fn test_register_onstop_callback_error_handling() -> Result<()> {
        let mut proc_mesh = allocate_local_proc_mesh().await?;

        // Extract events before wrapping in TrackedProcMesh.
        let events = proc_mesh
            .events()
            .expect("a freshly allocated ProcMesh should expose its events");
        let proc_events_cell = SharedCell::from(Mutex::new(events));

        let tracked_proc_mesh = TrackedProcMesh::from(proc_mesh);

        // One flag per callback so each execution can be observed separately.
        let callback1_executed = Arc::new(AtomicBool::new(false));
        let callback2_executed = Arc::new(AtomicBool::new(false));

        let first_flag = callback1_executed.clone();
        let second_flag = callback2_executed.clone();

        // First callback: completes successfully.
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                first_flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // Second callback: also completes successfully and must run after the
        // first one.
        tracked_proc_mesh
            .register_onstop_callback(move || async move {
                second_flag.store(true, Ordering::SeqCst);
            })
            .await?;

        // stop_mesh is expected to run both callbacks.
        let tracked_proc_mesh_cell = SharedCell::from(tracked_proc_mesh);
        PyProcMesh::stop_mesh(tracked_proc_mesh_cell, proc_events_cell).await?;

        assert!(
            callback1_executed.load(Ordering::SeqCst),
            "First callback should have been executed"
        );
        assert!(
            callback2_executed.load(Ordering::SeqCst),
            "Second callback should have been executed"
        );

        Ok(())
    }
}
0 commit comments