@@ -116,6 +116,9 @@ def init_weight_update_group(
         from vllm.distributed.parallel_state import get_world_group
 
         torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
+        if self.model_update_group is not None:
+            torchrl_logger.info("Model update group already initialized")
+            return
 
         # Get the local rank within the tensor parallel group
         tp_group = get_world_group()
@@ -1268,7 +1271,7 @@ def _update_weights_with_nccl_broadcast_simple(
 
         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for i, (name, weight) in enumerate(gpu_weights.items()):
-                torchrl_logger.info(
+                torchrl_logger.debug(
                     f"Processing weight {i + 1}/{len(gpu_weights)}: {name} {weight.shape}"
                 )
 
@@ -1279,11 +1282,11 @@ def _update_weights_with_nccl_broadcast_simple(
                 remotes.extend(worker_remotes)
 
                 # Step 2: Immediately broadcast this weight from master (rank 0)
-                torchrl_logger.info(f"Broadcasting weight {name} from master...")
+                torchrl_logger.debug(f"Broadcasting weight {name} from master...")
                 self._nccl_master_group.broadcast(
                     weight, src=0, stream=torch.cuda.current_stream()
                 )
-                torchrl_logger.info(f"Master broadcast completed for {name}")
+                torchrl_logger.debug(f"Master broadcast completed for {name}")
 
         # Wait for all workers to complete all weight updates
         torchrl_logger.info("Waiting for all worker updates to complete...")