-
Notifications
You must be signed in to change notification settings - Fork 1
JIT the whole perform_step and everything that is lower-level than that #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
def test_min_step_exception(steprule): | ||
steprule.min_step = 0.1 | ||
with pytest.raises(ValueError): | ||
steprule.suggest( | ||
previous_dt=1e-1, | ||
scaled_error_estimate=1_000_000_000, | ||
local_convergence_rate=1, | ||
) | ||
|
||
@staticmethod | ||
def test_max_step_exception(steprule): | ||
steprule.max_step = 10.0 | ||
with pytest.raises(ValueError): | ||
steprule.suggest( | ||
previous_dt=9.0, | ||
scaled_error_estimate=1 / 1_000_000_000, | ||
local_convergence_rate=1, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this functionality is removed. I left a comment in the perform_step, because maybe we want something similar back. THe current stuff was just unjittable...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
who needs a safetynet
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not familiar enough with all this to be sure that it's perfectly fine, but it looked ok and I'm really looking forward to trying this out right now :D Awesome work!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm taking my approval back for now :/ I tried to solve a PDE with the DiagonalEK1 from this branch, and my python console just crashed. So before we merge this in we should be very sure that these solvers can be used for all the experiments!
Okay, so @nathanaelbosch your problem should be solved by updating my branch to the most recent main. It brought up a valid point, however: does the jitted while_loop always add performance gains? I tried some things and found out that the answer is a strong YES for d << 1000. For higher dimensions, it is a NO. Since <1900 is not enough, I left the original perform step in as an option, which can be accessed with solve(compile_step={True,False}). We can use "True" for small problems, and "False" for large ones. (I suppose for large problems so much happens internally in some BLAS call that the python overhead in perform_step is entirely neglectible. Maybe the jax control flow internally carries some values which make it more expensive to use in these settings. I suppose that once/if the whole solve is jitted, this discrepancy is made up for.) (I also added the script that I used for the above to |
@@ -151,12 +148,12 @@ def evaluate_ode(t, f, mp, P, e1): | |||
H = e1 @ P | |||
return z, H | |||
|
|||
def attempt_step(self, state, dt, verbose=False): | |||
@partial(jax.jit, static_argnums=(0, 3, 7, 8)) | |||
def attempt_step(self, state, dt, f, t0, tmax, y0, df, df_diagonal): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we pass all the unused stuff here, just such that we can call attempt_step(..., *ivp)
? If so, I vote in favor for a long-but-explicit alternative (i.e., pass (keyword-)args explicitly) to avoid annoying bugs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we do it some other time and merge this in now? There are a few PRs in the pipeline now...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure 👍🏼
The title says it all. I hope you understand the code. Everything that is a bit odd, I tried to explain in comments.