
Commit c6cecc2

Authored by pierre.delaunay

Update to Pytorch 2.8 + cu129, update docker containers to 24.04 (#362)
* Pytorch 2.8 + cu129
* Recombine jax and torch env
* py3.12_torch2.8+cu129 (#364)
* Do not init process group on prepare step
* Add Distributed env variable on prepare when required
* Update purejaxrl to use the latest distrax + tfp-nightly
* Ignore tensorflow-probability
* Pin Dependencies (#366)
* Replace tree_map by tree.map
* Update container to match current cuda version
* Update dockerfile
* Add a timer for rsync and add the concept of a job runner pipeline
* Handle rerun of jobs with dependencies
* Make the server try its best even if the cluster is down

Co-authored-by: pierre.delaunay <[email protected]>
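The two prepare-related bullets above (not initializing the process group on the prepare step, only exporting the distributed environment variables) can be sketched as follows. This is a minimal illustration, not the repository's actual code: the `prepare` function name and its arguments are hypothetical, and the environment variables shown are the standard ones `torch.distributed` reads when using `env://` initialization.

```python
import os

def prepare(rank: int, world_size: int) -> None:
    """Hypothetical prepare step: export the distributed environment
    variables a later training run will need, WITHOUT calling
    torch.distributed.init_process_group() here. Group initialization
    is deferred to the run step itself."""
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    # Deliberately no init_process_group(...) on prepare.

prepare(0, 1)
print(os.environ["RANK"], os.environ["WORLD_SIZE"])  # prints "0 1"
```

Deferring `init_process_group` avoids opening rendezvous connections during a step that only stages data and environment, which matters when prepare runs on a node other than the one that will train.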
1 parent (e5538b2) commit c6cecc2

37 files changed: +1199 −766 lines

.pin/constraints-cuda-torch.txt

Lines changed: 78 additions & 57 deletions
Some generated files are not rendered by default.

benchmarks/brax/requirements.cuda.txt

Lines changed: 37 additions & 32 deletions
Some generated files are not rendered by default.

benchmarks/cleanrl_jax/main.py

Lines changed: 2 additions & 2 deletions
@@ -393,8 +393,8 @@ def convert_data(x: jnp.ndarray):
         x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
         return x

-    flatten_storage = jax.tree_map(flatten, storage)
-    shuffled_storage = jax.tree_map(convert_data, flatten_storage)
+    flatten_storage = jax.tree.map(flatten, storage)
+    shuffled_storage = jax.tree.map(convert_data, flatten_storage)

     def update_minibatch(agent_state, minibatch):
         (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
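The diff above swaps the deprecated top-level `jax.tree_map` for its replacement `jax.tree.map` (available under the `jax.tree` namespace in recent JAX releases). Both apply a function to every leaf of a pytree while preserving its structure; a small standalone illustration, with example data not taken from the benchmark:

```python
import jax

# jax.tree.map applies a function to every leaf of a nested
# container ("pytree") and returns a container of the same shape.
nested = {"a": [1, 2], "b": (3, 4)}
doubled = jax.tree.map(lambda x: x * 2, nested)
print(doubled)  # {'a': [2, 4], 'b': (6, 8)}
```

Because `jax.tree.map` is a drop-in replacement for `jax.tree_map`, the migration in the diff changes no behavior; it only moves off the deprecated alias.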

0 commit comments