@@ -361,7 +361,7 @@ def __init__(
361
361
362
362
def _initialize_data_system (self ):
363
363
num_n_samples = self .config .actor_rollout_ref .rollout .n
364
- # 1. 初始化TransferQueueStorage
364
+ # 1. initialize TransferQueueStorage
365
365
total_storage_size = self .config .data .train_batch_size * self .config .trainer .num_global_batch * num_n_samples
366
366
self .data_system_storage_units = {}
367
367
storage_placement_group = get_placement_group (self .config .trainer .num_data_storage_units , num_cpus_per_actor = 1 )
@@ -373,8 +373,9 @@ def _initialize_data_system(self):
373
373
self .data_system_storage_units [storage_unit_rank ] = storage_node
374
374
logging .info (f"TransferQueueStorageSimpleUnit #{ storage_unit_rank } has been created." )
375
375
376
- # 2. 初始化TransferQueueController
377
- # 这里支持多controller实例以实现负载均衡,支持大规模扩展。不同controller可分配至不同RL计算任务
376
+ # 2. initialize TransferQueueController
377
+ # we support initializing multiple controller instances for large-scale scenarios. Please allocate exactly
378
+ # one controller for a single WorkerGroup.
378
379
self .data_system_controllers = {}
379
380
controller_placement_group = get_placement_group (self .config .trainer .num_data_controllers , num_cpus_per_actor = 1 )
380
381
for controller_rank in range (self .config .trainer .num_data_controllers ):
@@ -388,8 +389,7 @@ def _initialize_data_system(self):
388
389
)
389
390
logging .info (f"TransferQueueController #{ controller_rank } has been created." )
390
391
391
- # 3. 将Controller注册至各个Storage
392
- # 每个Storage Unit拿到所有Controller的handler,通过Ray拿到对应的IP+端口,之后建立ZMQ Socket进行消息传输
392
+ # 3. register controller & storage
393
393
self .data_system_controller_infos = process_zmq_server_info (self .data_system_controllers )
394
394
self .data_system_storage_unit_infos = process_zmq_server_info (self .data_system_storage_units )
395
395
@@ -400,11 +400,11 @@ def _initialize_data_system(self):
400
400
]
401
401
)
402
402
403
- # 4. 创建Client
403
+ # 4. create client
404
+ # each client should be allocated to exactly one controller
404
405
self .data_system_client = AsyncTransferQueueClient (
405
406
client_id = "Trainer" ,
406
407
controller_infos = self .data_system_controller_infos [0 ],
407
- # TODO: 主控Client感知所有controller,WorkerGroup和Worker的Client感知一个controller
408
408
storage_infos = self .data_system_storage_unit_infos ,
409
409
)
410
410
@@ -1472,7 +1472,9 @@ def fit(self):
1472
1472
log_rollout_meta .reorder (balanced_idx )
1473
1473
self ._log_rollout_data (log_rollout_meta , reward_extra_infos_dict , timing_raw , rollout_data_dir )
1474
1474
1475
- # validate
1475
+ # TODO: clear meta after iteration
1476
+
1477
+ # TODO: validate
1476
1478
if (
1477
1479
self .val_reward_fn is not None
1478
1480
and self .config .trainer .test_freq > 0
0 commit comments