Skip to content

Conversation

neel04
Copy link
Contributor

@neel04 neel04 commented Jul 4, 2025

Jax doesn't like multi-threading. When using the eval harness during some jax code, the multithreading causes deadlocks blocking everything.

This PR just adds an optional flag to disable multi-threading, specifically in the stderr calculation in metrics.py since that seems to be the main culprit.

The idea is that JAX users can simply disable pass this flag and disable all MP. This will also have to be documented somewhere, I have no idea where.

@CLAassistant
Copy link

CLAassistant commented Jul 4, 2025

CLA assistant check
All committers have signed the CLA.

@baberabb
Copy link
Contributor

baberabb commented Jul 4, 2025

Hi! Thanks for the PR! This looks good overall. I made some slight changes to keep parity with mp-bootstrapping (seed and return type) and took the opportunity to add some docs while I was at it. Good to merge!

@baberabb baberabb merged commit 71d0289 into EleutherAI:main Jul 4, 2025
6 checks passed
@baberabb
Copy link
Contributor

baberabb commented Jul 4, 2025

Just though of: could be faster with a list comp, but not sure of the memory implications.

@neel04 neel04 deleted the bugfix/mp branch July 6, 2025 11:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants