Skip to content

JIT the whole perform_step and everything that is lower-level than that #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Sep 21, 2021

Conversation

pnkraemer
Copy link
Owner

The title says it all. I hope you understand the code. Everything that is a bit odd, I tried to explain in comments.

Comment on lines -127 to -143
def test_min_step_exception(steprule):
steprule.min_step = 0.1
with pytest.raises(ValueError):
steprule.suggest(
previous_dt=1e-1,
scaled_error_estimate=1_000_000_000,
local_convergence_rate=1,
)

@staticmethod
def test_max_step_exception(steprule):
steprule.max_step = 10.0
with pytest.raises(ValueError):
steprule.suggest(
previous_dt=9.0,
scaled_error_estimate=1 / 1_000_000_000,
local_convergence_rate=1,
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this functionality is removed. I left a comment in the perform_step, because maybe we want something similar back. THe current stuff was just unjittable...

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

who needs a safetynet

Copy link
Collaborator

@nathanaelbosch nathanaelbosch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not familiar enough with all this to be sure that it's perfectly fine, but it looked ok and I'm really looking forward to trying this out right now :D Awesome work!

@nathanaelbosch nathanaelbosch self-requested a review September 18, 2021 15:50
Copy link
Collaborator

@nathanaelbosch nathanaelbosch left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm taking my approval back for now :/ I tried to solve a PDE with the DiagonalEK1 from this branch, and my python console just crashed. So before we merge this in we should be very sure that these solvers can be used for all the experiments!

@pnkraemer
Copy link
Owner Author

pnkraemer commented Sep 19, 2021

Okay, so @nathanaelbosch your problem should be solved by updating my branch to the most recent main.

It brought up a valid point, however: does the jitted while_loop always add performance gains? I tried some things and found out that the answer is a strong YES for d << 1000. For higher dimensions, it is a NO. Since <1900 is not enough, I left the original perform step in as an option, which can be accessed with solve(compile_step={True,False}). We can use "True" for small problems, and "False" for large ones.
Since now every attempt_step is jitted anyway, even with perform_step=False, the performance gains should be fairly significant for not-increcibly-large problems (for those, I am confident it will not ruin gains though).

(I suppose for large problems so much happens internally in some BLAS call that the python overhead in perform_step is entirely neglectible. Maybe the jax control flow internally carries some values which make it more expensive to use in these settings. I suppose that once/if the whole solve is jitted, this discrepancy is made up for.)

(I also added the script that I used for the above to timings/. I am aware that we might want to get rid of the folder soon, but until then, it might be useful there)

@@ -151,12 +148,12 @@ def evaluate_ode(t, f, mp, P, e1):
H = e1 @ P
return z, H

def attempt_step(self, state, dt, verbose=False):
@partial(jax.jit, static_argnums=(0, 3, 7, 8))
def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we pass all the unused stuff here, just such that we can call attempt_step(..., *ivp)? If so, I vote in favor for a long-but-explicit alternative (i.e., pass (keyword-)args explicitly) to avoid annoying bugs

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we do it some other time and merge this in now? There are a few PRs in the pipeline now...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure 👍🏼

@pnkraemer pnkraemer merged commit 69ca29b into main Sep 21, 2021
@pnkraemer pnkraemer deleted the jit-perform-step branch September 21, 2021 19:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants