Skip to content

Commit c4eca0e

Browse files
committed
fix broadcast
1 parent 0277cff commit c4eca0e

File tree

1 file changed: +6 additions, −3 deletions

torchrl/modules/llm/backends/vllm/vllm_async.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def init_weight_update_group(
         from vllm.distributed.parallel_state import get_world_group

         torchrl_logger.info(f"=> in {type(self).__name__}.init_weight_update_group")
+        if self.model_update_group is not None:
+            torchrl_logger.info("Model update group already initialized")
+            return

         # Get the local rank within the tensor parallel group
         tp_group = get_world_group()
@@ -1268,7 +1271,7 @@ def _update_weights_with_nccl_broadcast_simple(

         with torch.cuda.device(0):  # Ensure we're on the correct CUDA device
             for i, (name, weight) in enumerate(gpu_weights.items()):
-                torchrl_logger.info(
+                torchrl_logger.debug(
                     f"Processing weight {i+1}/{len(gpu_weights)}: {name} {weight.shape}"
                 )

@@ -1279,11 +1282,11 @@ def _update_weights_with_nccl_broadcast_simple(
                 remotes.extend(worker_remotes)

                 # Step 2: Immediately broadcast this weight from master (rank 0)
-                torchrl_logger.info(f"Broadcasting weight {name} from master...")
+                torchrl_logger.debug(f"Broadcasting weight {name} from master...")
                 self._nccl_master_group.broadcast(
                     weight, src=0, stream=torch.cuda.current_stream()
                 )
-                torchrl_logger.info(f"Master broadcast completed for {name}")
+                torchrl_logger.debug(f"Master broadcast completed for {name}")

         # Wait for all workers to complete all weight updates
         torchrl_logger.info("Waiting for all worker updates to complete...")

0 commit comments

Comments
 (0)