Skip to content

Commit cc7c7da

Browse files
committed
apply review suggestions
1 parent 7c927e5 commit cc7c7da

File tree

7 files changed

+21
-136
lines changed

7 files changed

+21
-136
lines changed

recipe/transfer_queue/fsdp_workers.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

recipe/transfer_queue/main_ppo.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,7 @@ def add_actor_rollout_worker(self, config):
113113
from verl.single_controller.ray import RayWorkerGroup
114114

115115
if config.actor_rollout_ref.actor.strategy in {"fsdp", "fsdp2"}:
116-
from .fsdp_workers import (
117-
ActorRolloutRefWorker,
118-
AsyncActorRolloutRefWorker,
119-
)
116+
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
120117

121118
actor_rollout_cls = (
122119
AsyncActorRolloutRefWorker
@@ -126,10 +123,7 @@ def add_actor_rollout_worker(self, config):
126123
ray_worker_group_cls = RayWorkerGroup
127124

128125
elif config.actor_rollout_ref.actor.strategy == "megatron":
129-
from .megatron_workers import (
130-
ActorRolloutRefWorker,
131-
AsyncActorRolloutRefWorker,
132-
)
126+
from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
133127

134128
actor_rollout_cls = (
135129
AsyncActorRolloutRefWorker
@@ -152,7 +146,7 @@ def add_critic_worker(self, config):
152146
if config.critic.strategy in {"fsdp", "fsdp2"}:
153147
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
154148
if use_legacy_worker_impl in ["auto", "enable"]:
155-
from .fsdp_workers import CriticWorker
149+
from verl.workers.fsdp_workers import CriticWorker
156150
elif use_legacy_worker_impl == "disable":
157151
from verl.workers.roles import CriticWorker
158152

@@ -161,7 +155,7 @@ def add_critic_worker(self, config):
161155
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
162156

163157
elif config.critic.strategy == "megatron":
164-
from .megatron_workers import CriticWorker
158+
from verl.workers.megatron_workers import CriticWorker
165159

166160
else:
167161
raise NotImplementedError
@@ -203,9 +197,9 @@ def add_reward_model_worker(self, config):
203197
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
204198
if use_legacy_worker_impl in ["auto", "enable"]:
205199
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
206-
from .fsdp_workers import RewardModelWorker
200+
from verl.workers.fsdp_workers import RewardModelWorker
207201
elif config.reward_model.strategy == "megatron":
208-
from .megatron_workers import RewardModelWorker
202+
from verl.workers.megatron_workers import RewardModelWorker
209203
else:
210204
raise NotImplementedError
211205
elif use_legacy_worker_impl == "disable":

recipe/transfer_queue/megatron_workers.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

requirements.txt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,4 @@ uvicorn
2323
fastapi
2424
latex2sympy2_extended
2525
math_verify
26-
tensorboard
27-
git+https://github.com/TransferQueue/TransferQueue.git@a2ddb30
26+
tensorboard

requirements_transferqueue.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# requirements.txt records the full set of dependencies for development
2+
git+https://github.com/TransferQueue/TransferQueue.git@a2ddb30

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
"wandb",
4343
"packaging>=20.0",
4444
"tensorboard",
45-
"TransferQueue @ git+https://github.com/TransferQueue/TransferQueue.git@a2ddb30",
4645
]
4746

4847
TEST_REQUIRES = ["pytest", "pre-commit", "py-spy", "pytest-asyncio"]
@@ -58,6 +57,7 @@
5857
]
5958
TRL_REQUIRES = ["trl<=0.9.6"]
6059
MCORE_REQUIRES = ["mbridge"]
60+
TRANSFERQUEUE_REQUIRES = ["TransferQueue @ git+https://github.com/TransferQueue/TransferQueue.git@a2ddb30"]
6161

6262
extras_require = {
6363
"test": TEST_REQUIRES,
@@ -69,6 +69,7 @@
6969
"sglang": SGLANG_REQUIRES,
7070
"trl": TRL_REQUIRES,
7171
"mcore": MCORE_REQUIRES,
72+
"transferqueue": TRANSFERQUEUE_REQUIRES,
7273
}
7374

7475

verl/single_controller/base/worker.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,16 @@ def _query_collect_info(self, mesh_name: str):
129129
"""
130130
assert mesh_name in self.__collect_dp_rank, f"{mesh_name} is not registered in {self.__class__.__name__}"
131131
return self.__collect_dp_rank[mesh_name]
132+
133+
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
134+
def create_transferqueue_client(self, controller_infos, storage_infos, role="train"):
135+
from verl.utils.transferqueue_utils import create_transferqueue_client
136+
137+
create_transferqueue_client(
138+
client_id=f"{role}_worker_{self.rank}",
139+
controller_infos=controller_infos,
140+
storage_infos=storage_infos,
141+
)
132142

133143
@classmethod
134144
def env_keys(cls):

0 commit comments

Comments
 (0)