Skip to content

Commit 54fdcfd

Browse files
committed
some fixs
Signed-off-by: Lehui Liu <[email protected]>
1 parent b865627 commit 54fdcfd

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

python/ray/train/v2/jax/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
@dataclass
2121
class JaxConfig(BackendConfig):
2222
use_tpu: bool = False
23+
use_gpu: bool = False
2324

2425
@property
2526
def backend_cls(self):
@@ -108,6 +109,7 @@ def on_start(self, worker_group: WorkerGroup, backend_config: JaxConfig):
108109
master_addr_with_port=master_addr_with_port,
109110
num_workers=len(worker_group),
110111
index=i,
112+
resources_per_worker=worker_group.get_resources_per_worker()
111113
)
112114
)
113115
ray.get(setup_futures)

python/ray/train/v2/jax/jax_trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def __init__(
132132
if not jax_config:
133133
jax_config = JaxConfig(
134134
use_tpu=scaling_config.use_tpu,
135+
use_gpu=scaling_config.use_gpu,
135136
)
136137
super(JaxTrainer, self).__init__(
137138
train_loop_per_worker=train_loop_per_worker,

0 commit comments

Comments
 (0)