Skip to content

Add support for optimizer checkpointing #579

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Apr 2, 2025
Merged

Add support for optimizer checkpointing #579

merged 20 commits into from
Apr 2, 2025

Conversation

bwohlberg
Copy link
Collaborator

Add support for optimizer checkpointing. Also bump supported jaxlib/jax versions and address issues in docs.

@bwohlberg bwohlberg added documentation Improvements or additions to documentation enhancement New feature or request labels Mar 5, 2025
Copy link

codecov bot commented Mar 5, 2025

Codecov Report

Attention: Patch coverage is 92.30769% with 6 lines in your changes missing coverage. Please review.

Project coverage is 93.55%. Comparing base (d4ef866) to head (2ab4a3b).
Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
scico/optimize/_common.py 83.33% 5 Missing ⚠️
scico/numpy/util.py 95.45% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #579      +/-   ##
==========================================
- Coverage   93.58%   93.55%   -0.04%     
==========================================
  Files          92       92              
  Lines        6123     6182      +59     
==========================================
+ Hits         5730     5783      +53     
- Misses        393      399       +6     
Flag Coverage Δ
unittests 93.55% <92.31%> (-0.04%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Comment on lines +173 to +177
admm0.solve()
with tempfile.TemporaryDirectory() as tmpdir:
path = os.path.join(tmpdir, "admm.npz")
admm0.save_state(path)
admm0.solve()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is admm0.solve() called twice?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sequence is as follows:

  1. Solve 5 iterations of admm0.
  2. Save admm0 state.
  3. Solve another 5 iterations of admm0.
  4. Initialize admm1 with state of admm0 after 5 iterations.
  5. Solve another 5 iterations of admm1.
  6. Compare results for admm0 and admm.

The idea is that the additional 5 iterations after saving the state allow the solutions to diverge if any component of the state is not properly saved, so that the final test can just be on the solution rather than on all state components.

@bwohlberg bwohlberg merged commit 4d04018 into main Apr 2, 2025
19 checks passed
@bwohlberg bwohlberg deleted the brendt/checkpoint branch April 2, 2025 15:40
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants