Skip to content

Conversation

@ali-ramadhan
Copy link
Member

@ali-ramadhan ali-ramadhan commented Oct 30, 2025

This PR refactors how the Checkpointer works by now checkpointing simulations, rather than just models. This is needed as the simulations (+ output writers, callbacks, etc.) all contain crucial information needed to properly restore/pickup a simulation and continue time stepping.

Basic design idea:

  • We now have two new functions: prognostic_state(obj) which returns a named tuple corresponding to the prognostic state of obj and restore_prognostic_state!(obj, state) which restores obj based on information contained in state (which is a named tuple and is read from a checkpoint file).
  • Objects are checkpointed recursively by serializing prognostic information to the JLD2 checkpoint file.
  • The goal is for checkpointing to be flexible enough that we can very easily use it for different types of simulations, e.g. coupled simulations in ClimaOcean.jl by just defining prognostic_state and restore_prognostic_state!.

Right now I've only implemented proper checkpointing for non-hydrostatic model but it looks like it'll be straightforward to do it for hydrostatic and shallow water models. I'm working on adding comprehensive testing too.

Will continue working on this PR, but any feedback is very welcome!

Resolves #1249
Resolves #2866
Resolves #3670
Resolves #3845
Resolves #4516
Resolves #4857


Rhetorical aside

In general, the checkpointer is assuming that the simulation setup is the same. So only prognostic state information that changes will be checkpointed (e.g. field data, TimeInterval.actuations, etc.). The approach I have been taking (based on #4857) is to only checkpoint the prognostic state.

Should we operate under this assumption? I think so because not doing so can lead to a lot of undefined behavior. The checkpointer should not be responsible for checking that you set up the same simulation as the one that was checkpointed.

For example, take the SpecifiedTimes schedule. It has two properties times and previous_actuation. Since previous_actuation changes as the simulation runs, only previous_actuation needs to be checkpointed.

This leads to the possibility of the user changing times then picking up previous_actuation which can lead to undefined behavior. I think this is fine, because the checkpointer only works assuming you set up the same simulation as the one that was checkpointed.

Checkpointing both times and previous_actuation allows us to check that times is the same when restoring. But I don't think this is the checkpointer's responsibility.

@ali-ramadhan
Copy link
Member Author

ali-ramadhan commented Nov 13, 2025

Looks like tests will all pass 🎉 I'll start testing the checkpointing of increasingly complex simulations while iterating on the design! This way we'll be able to weed out most bugs and issues.

@ali-ramadhan ali-ramadhan changed the title Checkpointing simulations (v0.103.0) Checkpointing simulations Dec 12, 2025
@ali-ramadhan
Copy link
Member Author

ali-ramadhan commented Dec 12, 2025

Ok I think this PR is finally ready for review! I went through and added checkpointing support + tests to different time steppers, models, schedules, turbulence closures, etc. but let me know if you feel like I missed something.

This PR has a few breaking changes so we should tag v0.103.0 once it's merged:

  1. The Checkpointer constructor no longer has the properties keyword. prognostic_state decides exactly what gets checkpointed.
  2. set!(model, filepath::String) is now set!(simulation, pickup) where pickup is Bool, Int, or String.
  3. Checkpoint files are not backwards compatible. You will not be able to restore from an old (pre v0.103.0) checkpoint file. Although this has generally always been true.

The PR is getting quite large and distributed CI is not working, so I'll test checkpointing support for distributed models in a subsequent PR.

@tomchor
Copy link
Collaborator

tomchor commented Dec 12, 2025

This leads to the possibility of the user changing times then picking up previous_actuation which can lead to undefined behavior. I think this is fine, because the checkpointer only works assuming you set up the same simulation as the one that was checkpointed.

Checkpointing both times and previous_actuation allows us to check that times is the same when restoring. But I don't think this is the checkpointer's responsibility.

There's been a lot of progress since this PR was open, so I'm wondering if this is still true. I don't oppose this, but I also don't oppose saving non-prognostic information when easy and throwing a warning to the user if we flag differences there. Mistakes happen and I think this sort of user-friendliness can avoid some headache. Again, only when it's to do so though. I definitely don't think Checkpointer needs to check every single aspect of a Simulation to see if everything matches up.

Comment on lines +72 to +75
```julia
set!(simulation, filepath) # restore from specific file (no Checkpointer required)
set!(simulation, true) # restore from latest checkpoint (requires Checkpointer)
set!(simulation, iteration) # restore from specific iteration (requires Checkpointer)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized that set!(simulation, true) looks vague. As in if you saw that line in a script randomly you wouldn't know what it does. So I might add kwargs so it instead looks like

set!(simulation; filepath="spinup.jld2")
set!(simulation; use_latest_checkpoint=true)
set!(simulation; iteration=12345)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only qualm is that "false" means nothing.

An altnerative is a symbol, eg

set!(simulation, :latest_checkpoint)

i do like the kwarg for the other options.

Comment on lines +80 to +89
For cluster jobs with time limits, use `WallTimeInterval` to checkpoint based on elapsed
wall-clock time rather than simulation time or iterations:

```julia
# Checkpoint every 30 minutes of wall-clock time
Checkpointer(model, schedule=WallTimeInterval(30minute), prefix="checkpoint")
```

This ensures checkpoints are saved regularly even if individual time steps vary significantly.

Copy link
Collaborator

@tomchor tomchor Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may be a job for another PR, but should we add option to write a Checkpoint when the wall_time_limit is exceeded? It might be super useful exactly for this sort of cluster situation such that we don't run any time-step twice regardless of how many pickups a simulation ends up needing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely a good use case. I was thinking we can do something more general and checkpoint after a simulation stops running (whether it is done or not)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good to me!

Comment on lines 1591 to 1595
return (
schedule = prognostic_state(writer.schedule),
part = writer.part,
windowed_time_averages = isempty(wta_outputs) ? nothing : wta_outputs,
)
Copy link
Collaborator

@tomchor tomchor Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we wanna enforce the formatting we usually prefer in Oceananigans? Namely:

Suggested change
return (
schedule = prognostic_state(writer.schedule),
part = writer.part,
windowed_time_averages = isempty(wta_outputs) ? nothing : wta_outputs,
)
return (schedule = prognostic_state(writer.schedule),
part = writer.part,
windowed_time_averages = isempty(wta_outputs) ? nothing : wta_outputs)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I prefer keeping it consistent (and changing it globally if we want) but it is not the most critical thing in the world

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I can be a bit sloppy with coding style. I'll make a pass and use a coding style that is consistent with the rest of Oceananigans.jl.

Comment on lines +403 to +409
return (
η = prognostic_state(fs.η),
barotropic_velocities = prognostic_state(fs.barotropic_velocities),
filtered_state = prognostic_state(fs.filtered_state),
timestepper = prognostic_state(fs.timestepper),
)
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment here about formatting

Comment on lines +174 to +183
return (
ηᵐ = prognostic_state(ts.ηᵐ),
ηᵐ⁻¹ = prognostic_state(ts.ηᵐ⁻¹),
ηᵐ⁻² = prognostic_state(ts.ηᵐ⁻²),
Uᵐ⁻¹ = prognostic_state(ts.Uᵐ⁻¹),
Uᵐ⁻² = prognostic_state(ts.Uᵐ⁻²),
Vᵐ⁻¹ = prognostic_state(ts.Vᵐ⁻¹),
Vᵐ⁻² = prognostic_state(ts.Vᵐ⁻²),
)
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here (and other places below as well)

Comment on lines +340 to +372

#####
##### Checkpointing the JLD2Writer
#####

function prognostic_state(writer::JLD2Writer)
# Only checkpoint WindowedTimeAverage outputs (which have accumulated state).
wta_outputs = NamedTuple(name => prognostic_state(output)
for (name, output) in pairs(writer.outputs)
if output isa WindowedTimeAverage)

return (
schedule = prognostic_state(writer.schedule),
part = writer.part,
windowed_time_averages = isempty(wta_outputs) ? nothing : wta_outputs,
)
end

function restore_prognostic_state!(writer::JLD2Writer, state)
restore_prognostic_state!(writer.schedule, state.schedule)
writer.part = state.part

# Restore WindowedTimeAverage outputs if present
if hasproperty(state, :windowed_time_averages) && !isnothing(state.windowed_time_averages)
for (name, wta_state) in pairs(state.windowed_time_averages)
if haskey(writer.outputs, name) && writer.outputs[name] isa WindowedTimeAverage
restore_prognostic_state!(writer.outputs[name], wta_state)
end
end
end

return writer
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this code exactly the same for all OutputWriters? If so, can we write it only once? In the future we might also have other supported formats (like zarr) and it'd be good to unify as much code as possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if there are small differences, we can open an issue to smooth out those wrinkles between the two --- effectively creating a "convention for output writers" with the benefits mentioned by @tomchor

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I should be able to combine prognostic_state and restore_prognostic_state! for multiple writers.

@glwagner
Copy link
Member

There's been a lot of progress since this PR was open, so I'm wondering if this is still true. I don't oppose this, but I also don't oppose saving non-prognostic information when easy and throwing a warning to the user if we flag differences there. Mistakes happen and I think this sort of user-friendliness can avoid some headache. Again, only when it's to do so though. I definitely don't think Checkpointer needs to check every single aspect of a Simulation to see if everything matches up.

@tomchor can you specify what you mean by "this" in your comment?

Possibly unrelated, but I'm wondering if we think there would be value in being as strict as humanly possible about the accuracy of picking up from a checkpoint. Otherwise critical operational applications will have to implement their own testing system because they could not rely on the Oceananigans in-built system for it. If there is a need to relax strictness for the purpose of productivity, we might have a checkpointer "mode" (strict=true, false)

Copy link
Collaborator

@tomchor tomchor Dec 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work with testing! Pretty comprehensive. My only suggestion here is maybe instead of write two functions like test_XYZ_hydrostatic() and test_XYZ_nonhodrostatic(), should we write one function called test_XYZ() and pass the model as a argument? There might be some if/else statements in there, but I feel like this would be an easier test suite to keep track of and maintain in the long run, no?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have combined some but definitely seeing room to combine more. I agree it'll make it easier to maintain in the future.

@tomchor
Copy link
Collaborator

tomchor commented Dec 12, 2025

There's been a lot of progress since this PR was open, so I'm wondering if this is still true. I don't oppose this, but I also don't oppose saving non-prognostic information when easy and throwing a warning to the user if we flag differences there. Mistakes happen and I think this sort of user-friendliness can avoid some headache. Again, only when it's to do so though. I definitely don't think Checkpointer needs to check every single aspect of a Simulation to see if everything matches up.

@tomchor can you specify what you mean by "this" in your comment?

Possibly unrelated, but I'm wondering if we think there would be value in being as strict as humanly possible about the accuracy of picking up from a checkpoint. Otherwise critical operational applications will have to implement their own testing system because they could not rely on the Oceananigans in-built system for it. If there is a need to relax strictness for the purpose of productivity, we might have a checkpointer "mode" (strict=true, false)

Sorry, by this I meant the general idea in the sentences I quoted (maybe I didn't quote the best ones!). Mainly that Checkpointer should assume that the simulation setup is the same and only keep track of prognostic state information. So, to repeat @ali-ramadhan's example:

For example, take the SpecifiedTimes schedule. It has two properties times and previous_actuation. Since previous_actuation changes as the simulation runs, only previous_actuation needs to be checkpointed.

This leads to the possibility of the user changing times then picking up previous_actuation which can lead to undefined behavior. I think this is fine, because the checkpointer only works assuming you set up the same simulation as the one that was checkpointed.

Checkpointing both times and previous_actuation allows us to check that times is the same when restoring. But I don't think this is the checkpointer's responsibility.

I'm not opposed to this, but I wonder if it's useful to save times and throw an error if the checkpointer realizes that this property changed (which usually is due to human error). This is ofc just one example of something simple that could be checked.

@navidcy
Copy link
Member

navidcy commented Dec 12, 2025

The PR is getting quite large and distributed CI is not working, so I'll test checkpointing support for distributed models in a subsequent PR.

There was always a distributed test failing at #4005 and never understood why...

@glwagner
Copy link
Member

I will review again once @tomchor comments have been addressed but overall, I think it is excellent and very close. Most of my thoughts were 1:1 overlap with @tomchor's comments.

ali-ramadhan and others added 3 commits December 12, 2025 11:57
@ali-ramadhan
Copy link
Member Author

ali-ramadhan commented Dec 12, 2025

Thanks for the reviews! I made a TODO list to tackle and I'll request another round of reviews. Please feel free to add to it.

  • Add kwargs to set!(simulation, ...).
  • Checkpoint after a simulation stops running (whether it is done or not).
  • Use a coding style that is consistent with the rest of Oceananigans.jl.
  • Combine prognostic_state and restore_prognostic_state! for multiple writers.
  • Combine test functions where possible, e.g. test_XYZ_hydrostatic() and test_XYZ_nonhodrostatic() -> test_XYZ().
  • Maybe tackle Restarting simulation with different halo size #3206 if it's easy (to make @simone-silvestri happy).

Possibly unrelated, but I'm wondering if we think there would be value in being as strict as humanly possible about the accuracy of picking up from a checkpoint.

@glwagner Are you suggesting that we checkpoint the diagnostic state so it can be used to make sure the simulation was restored properly? Or are you suggesting checkpointing things like TimeInterval.interval (not diagnostic nor prognostic?) to make sure it is the same when restoring from checkpoint?

There was always a distributed test failing at #4005 and never understood why...

@navidcy Right now it seems like distributed CI even fails to initialize here. I tried looking through Buildkite runs from #4005 and couldn't find a failure with an actual error. Maybe I wouldn't think too much about it until we can run distributed checkpointing tests with/after this PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

5 participants