1 change: 1 addition & 0 deletions optimistix/__init__.py
@@ -76,6 +76,7 @@
OptaxMinimiser as OptaxMinimiser,
polak_ribiere as polak_ribiere,
SteepestDescent as SteepestDescent,
Zoom as Zoom,
)


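A note for context, not part of the diff: a zoom line search in the usual Nocedal–Wright formulation (Algorithms 3.5–3.6) brackets a step length satisfying the strong Wolfe conditions, and it is the curvature condition that requires the gradient at the trial point `y_eval`; this is presumably what motivates the `_needs_grad_at_y_eval` flag introduced below. Assuming this `Zoom` follows that formulation, for a descent direction $p$ and step length $\alpha$ it seeks

$$
f(y + \alpha p) \le f(y) + c_1 \, \alpha \, \nabla f(y)^\top p,
\qquad
\bigl|\nabla f(y + \alpha p)^\top p\bigr| \le c_2 \, \bigl|\nabla f(y)^\top p\bigr|,
$$

with $0 < c_1 < c_2 < 1$.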
4 changes: 3 additions & 1 deletion optimistix/_search.py
@@ -320,6 +320,8 @@ class AbstractSearch(eqx.Module, Generic[Y, _FnInfo, _FnEvalInfo, SearchState]):
See [this documentation](./introduction.md) for more information.
"""

_needs_grad_at_y_eval: ClassVar[bool]

@abc.abstractmethod
def init(self, y: Y, f_info_struct: _FnInfo) -> SearchState:
"""Is called just once, at the very start of the entire optimisation problem.
@@ -363,7 +365,7 @@ def step(
- `f_info`: An [`optimistix.FunctionInfo`][] describing information about `f`
evaluated at `y`, the gradient of `f` at `y`, etc.
- `f_eval_info`: An [`optimistix.FunctionInfo`][] describing information about
`f` evaluated at `y`, the gradient of `f` at `y`, etc.
`f` evaluated at `y_eval`, the gradient of `f` at `y_eval`, etc.
- `state`: the evolving state of the repeated searches.

**Returns:**
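The new `_needs_grad_at_y_eval` class variable is the contract every concrete search now declares: `True` means the solver hands the search a `FunctionInfo.EvalGrad` (gradient already computed at `y_eval`), `False` means a plain `FunctionInfo.Eval` is enough and the gradient is only computed once a step is accepted. A rough, untested sketch of what a third-party search might look like after this change (the `FixedStep` class and its fields are illustrative, not optimistix API):

```python
# Sketch only: a trivial search that always accepts a fixed step size, so it
# never needs the gradient at y_eval. Class and field names here are made up.
from typing import ClassVar

import jax.numpy as jnp
import optimistix as optx
from optimistix import RESULTS


class FixedStep(optx.AbstractSearch):
    _needs_grad_at_y_eval: ClassVar[bool] = False  # the new contract
    step_size: float = 0.1

    def init(self, y, f_info_struct):
        return None  # no state carried between search steps

    def step(self, first_step, y, y_eval, f_info, f_eval_info, state):
        # Always accept the trial point and propose the same step size again.
        accept = jnp.array(True)
        return jnp.asarray(self.step_size), accept, RESULTS.successful, None
```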
1 change: 1 addition & 0 deletions optimistix/_solver/__init__.py
@@ -48,3 +48,4 @@
ClassicalTrustRegion as ClassicalTrustRegion,
LinearTrustRegion as LinearTrustRegion,
)
from .zoom import Zoom as Zoom
3 changes: 2 additions & 1 deletion optimistix/_solver/backtracking.py
@@ -1,4 +1,4 @@
from typing import cast, TypeAlias
from typing import cast, ClassVar, TypeAlias

import equinox as eqx
import jax.numpy as jnp
@@ -29,6 +29,7 @@ class BacktrackingArmijo(AbstractSearch[Y, _FnInfo, _FnEvalInfo, _BacktrackingSt
decrease_factor: ScalarLike = 0.5
slope: ScalarLike = 0.1
step_init: ScalarLike = 1.0
_needs_grad_at_y_eval: ClassVar[bool] = False

def __post_init__(self):
self.decrease_factor = eqx.error_if(
24 changes: 19 additions & 5 deletions optimistix/_solver/gradient_methods.py
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, Generic, TypeAlias
from typing import Any, cast, Generic, TypeAlias

import equinox as eqx
import jax
@@ -131,7 +131,9 @@ class AbstractGradientDescent(AbstractMinimiser[Y, Aux, _GradientDescentState]):
norm: AbstractVar[Callable[[PyTree], Scalar]]
descent: AbstractVar[AbstractDescent[Y, FunctionInfo.EvalGrad, Any]]
search: AbstractVar[
AbstractSearch[Y, FunctionInfo.EvalGrad, FunctionInfo.Eval, Any]
AbstractSearch[
Y, FunctionInfo.EvalGrad, FunctionInfo.Eval | FunctionInfo.EvalGrad, Any
]
]

def init(
@@ -170,19 +172,31 @@ def step(
f_eval, lin_fn, aux_eval = jax.linearize(
lambda _y: fn(_y, args), state.y_eval, has_aux=True
)

if self.search._needs_grad_at_y_eval:
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode)
f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
else:
f_eval_info = FunctionInfo.Eval(f_eval)

step_size, accept, search_result, search_state = self.search.step(
state.first_step,
y,
state.y_eval,
state.f_info,
FunctionInfo.Eval(f_eval),
f_eval_info,
state.search_state,
)

def accepted(descent_state):
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
nonlocal f_eval_info

if not self.search._needs_grad_at_y_eval:
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)

f_eval_info = cast(FunctionInfo.EvalGrad, f_eval_info)

f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
descent_state = self.descent.query(state.y_eval, f_eval_info, descent_state)
y_diff = (state.y_eval**ω - y**ω).ω
f_diff = (f_eval**ω - state.f_info.f**ω).ω
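Useful background for reading this hunk: `jax.linearize` returns the function value together with a linearised (JVP) function, and the `lin_to_grad` call then recovers the gradient from that linear map, either by transposing it (reverse mode) or by evaluating it against basis vectors (forward mode). A standalone illustration of that mechanism, independent of optimistix:

```python
# Standalone illustration (not optimistix code) of how a gradient can be
# recovered from the output of jax.linearize, which is conceptually what the
# solver's lin_to_grad call does.
import jax
import jax.numpy as jnp


def f(y):
    return jnp.sum(y**2) + jnp.sin(y[0])


y = jnp.array([1.0, 2.0, 3.0])
f_val, lin_fn = jax.linearize(f, y)  # value at y plus the JVP at y

# Reverse-style: transpose the linear map and pull back a unit cotangent.
(grad_rev,) = jax.linear_transpose(lin_fn, y)(jnp.ones_like(f_val))

# Forward-style: push each basis vector through the linear map.
grad_fwd = jnp.stack([lin_fn(e) for e in jnp.eye(y.size)])

assert jnp.allclose(grad_rev, jax.grad(f)(y))
assert jnp.allclose(grad_fwd, jax.grad(f)(y))
```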
3 changes: 2 additions & 1 deletion optimistix/_solver/learning_rate.py
@@ -1,4 +1,4 @@
from typing import cast
from typing import cast, ClassVar

import equinox as eqx
import jax.numpy as jnp
@@ -16,6 +16,7 @@ def _typed_asarray(x: ScalarLike) -> Array:
class LearningRate(AbstractSearch[Y, FunctionInfo, FunctionInfo, None]):
"""Move downhill by taking a step of the fixed size `learning_rate`."""

_needs_grad_at_y_eval: ClassVar[bool] = False
learning_rate: ScalarLike = eqx.field(converter=_typed_asarray)

def init(self, y: Y, f_info_struct: FunctionInfo) -> None:
28 changes: 19 additions & 9 deletions optimistix/_solver/quasi_newton.py
@@ -1,6 +1,6 @@
import abc
from collections.abc import Callable
from typing import Any, Generic, TypeVar
from typing import Any, cast, Generic, TypeVar

import equinox as eqx
import jax
@@ -30,6 +30,7 @@
from .._solution import RESULTS
from .backtracking import BacktrackingArmijo
from .gauss_newton import NewtonDescent
from .zoom import Zoom


_Hessian = TypeVar(
@@ -118,10 +119,6 @@ class AbstractQuasiNewton(
structure and the Hessian update state, while the latter is called to compute an
update to the approximation of the Hessian or the inverse Hessian.

Already supported schemes to form inverse Hessian and Hessian approximations are
implemented in `optimistix.AbstractBFGS`, `optimistix.AbstractDFP` and
`optimistix.AbstractLBFGS`.

Supports the following `options`:

- `autodiff_mode`: whether to use forward- or reverse-mode autodifferentiation to
@@ -210,29 +207,40 @@ def step(
f_eval, lin_fn, aux_eval = jax.linearize(
lambda _y: fn(_y, args), state.y_eval, has_aux=True
)

if self.search._needs_grad_at_y_eval:
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode)
f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)
else:
f_eval_info = FunctionInfo.Eval(f_eval)

step_size, accept, search_result, search_state = self.search.step(
state.first_step,
y,
state.y_eval,
state.f_info,
FunctionInfo.Eval(f_eval),
f_eval_info, # pyright: ignore # TODO Fix (jhaffner)
state.search_state,
)

def accepted(descent_state):
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
nonlocal f_eval_info

if not self.search._needs_grad_at_y_eval:
grad = lin_to_grad(lin_fn, state.y_eval, autodiff_mode=autodiff_mode)
f_eval_info = FunctionInfo.EvalGrad(f_eval, grad)

f_eval_info, hessian_update_state = self.update_hessian(
y,
state.y_eval,
state.f_info,
FunctionInfo.EvalGrad(f_eval, grad),
cast(FunctionInfo.EvalGrad, f_eval_info),
state.hessian_update_state,
)

descent_state = self.descent.query(
state.y_eval,
f_eval_info, # pyright: ignore
f_eval_info,
descent_state,
)
y_diff = (state.y_eval**ω - y**ω).ω
@@ -447,6 +455,7 @@ def __init__(
norm: Callable[[PyTree], Scalar] = max_norm,
use_inverse: bool = True,
verbose: frozenset[str] = frozenset(),
search: AbstractSearch = Zoom(initial_guess_strategy="one"),
):
self.rtol = rtol
self.atol = atol
@@ -606,6 +615,7 @@ def __init__(
norm: Callable[[PyTree], Scalar] = max_norm,
use_inverse: bool = True,
verbose: frozenset[str] = frozenset(),
search: AbstractSearch = Zoom(initial_guess_strategy="one"),
):
self.rtol = rtol
self.atol = atol
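A usage sketch, assuming the two `__init__`s touched here belong to the concrete quasi-Newton solvers (e.g. `optimistix.BFGS`): with this PR they default to the Zoom line search with the `"one"` initial-guess strategy unless a `search` is passed explicitly. The tolerances and test function below are arbitrary.

```python
# Usage sketch (not from the PR): BFGS with the new default Zoom search.
import jax.numpy as jnp
import optimistix as optx


def rosenbrock(y, args):
    x1, x2 = y
    return (1.0 - x1) ** 2 + 100.0 * (x2 - x1**2) ** 2


# Now equivalent to passing search=Zoom(initial_guess_strategy="one").
solver = optx.BFGS(rtol=1e-6, atol=1e-6)
sol = optx.minimise(rosenbrock, solver, jnp.array([-1.2, 1.0]))
print(sol.value)  # expected to be close to [1., 1.]
```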
3 changes: 2 additions & 1 deletion optimistix/_solver/trust_region.py
@@ -1,5 +1,5 @@
import abc
from typing import TypeAlias, TypeVar
from typing import ClassVar, TypeAlias, TypeVar

import equinox as eqx
import jax.numpy as jnp
@@ -46,6 +46,7 @@ class _AbstractTrustRegion(AbstractSearch[Y, _FnInfo, _FnEvalInfo, _TrustRegionS
low_cutoff: AbstractVar[ScalarLike]
high_constant: AbstractVar[ScalarLike]
low_constant: AbstractVar[ScalarLike]
_needs_grad_at_y_eval: ClassVar[bool] = False

def __post_init__(self):
# You would not expect `self.low_cutoff` or `self.high_cutoff` to