Skip to content

Commit 298a9a9

Browse files
authored
Support autodiff with DifferentiationInterface (#1131)
* Support autodiff with DifferentiationInterface * Fix versions * Typo * Outdated test code * Structure test sets for easier debugging
1 parent 8e305a4 commit 298a9a9

File tree

7 files changed

+76
-64
lines changed

7 files changed

+76
-64
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ jobs:
1111
strategy:
1212
matrix:
1313
version:
14-
- "1.6"
14+
- "1.10"
1515
- "1"
1616
- "pre"
1717
os:

Project.toml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ ForwardDiff = "0.10"
3030
LineSearches = "7.0.1"
3131
LinearAlgebra = "<0.0.1, 1.6"
3232
MathOptInterface = "1.17"
33-
NLSolversBase = "~7.8.0"
33+
NLSolversBase = "7.9.0"
3434
NaNMath = "0.3.2, 1"
3535
OptimTestProblems = "2.0.3"
3636
Parameters = "0.10, 0.11, 0.12"
@@ -40,9 +40,10 @@ Random = "<0.0.1, 1.6"
4040
SparseArrays = "<0.0.1, 1.6"
4141
StatsBase = "0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
4242
Test = "<0.0.1, 1.6"
43-
julia = "1.6"
43+
julia = "1.10"
4444

4545
[extras]
46+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4647
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
4748
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
4849
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
@@ -52,8 +53,9 @@ OptimTestProblems = "cec144fc-5a64-5bc6-99fb-dde8f63e154c"
5253
PositiveFactorizations = "85a6dd25-e78a-55b7-8502-1745935b8125"
5354
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
5455
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
56+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
5557
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
5658
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5759

5860
[targets]
59-
test = ["Test", "Distributions", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "LineSearches", "NLSolversBase", "PositiveFactorizations"]
61+
test = ["Test", "Distributions", "MathOptInterface", "Measurements", "OptimTestProblems", "Random", "RecursiveArrayTools", "StableRNGs", "LineSearches", "NLSolversBase", "PositiveFactorizations", "ReverseDiff", "ADTypes"]

docs/src/user/gradientsandhessians.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ Automatic differentiation techniques are a middle ground between finite differen
1616

1717
Reverse-mode automatic differentiation can be seen as an automatic implementation of the adjoint method mentioned above, and requires a runtime comparable to only one evaluation of ``f``. It is however considerably more complex to implement, requiring to record the execution of the program to then run it backwards, and incurs a larger overhead.
1818

19-
Forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=:forward` keyword to `optimize`. Reverse-mode automatic differentiation is not supported explicitly yet (although you can use it by writing your own `g!` function). There are a number of implementations in Julia, such as [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl).
19+
Forward-mode automatic differentiation is supported through the [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) package by providing the `autodiff=:forward` keyword to `optimize`.
20+
More generic automatic differentiation is supported thanks to [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), by setting `autodiff` to any compatible backend object from [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
21+
For instance, the user can choose `autodiff=AutoReverseDiff()`, `autodiff=AutoEnzyme()`, `autodiff=AutoMooncake()` or `autodiff=AutoZygote()` for a reverse-mode gradient computation, which is generally faster than forward mode on large inputs.
22+
Each of these choices requires loading the corresponding package beforehand.
2023

2124
## Example
2225

src/multivariate/optimize/interface.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -58,26 +58,26 @@ fallback_method(d::OnceDifferentiable) = LBFGS()
5858
fallback_method(d::TwiceDifferentiable) = Newton()
5959

6060
# promote the objective (tuple of callables or an AbstractObjective) according to method requirement
61-
promote_objtype(method, initial_x, autodiff::Symbol, inplace::Bool, args...) = error("No default objective type for $method and $args.")
61+
promote_objtype(method, initial_x, autodiff, inplace::Bool, args...) = error("No default objective type for $method and $args.")
6262
# actual promotions, notice that (args...) captures FirstOrderOptimizer and NonDifferentiable, etc
63-
promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool, args...) = NonDifferentiable(args..., x, real(zero(eltype(x))))
64-
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f) = OnceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
65-
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, args...) = OnceDifferentiable(args..., x, real(zero(eltype(x))); inplace = inplace)
66-
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f, g, h) = OnceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace)
67-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f) = TwiceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
68-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f::NotInplaceObjective) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
69-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f::InplaceObjective) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
70-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f::NLSolversBase.InPlaceObjectiveFGHv) = TwiceDifferentiableHV(f, x)
71-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f::NLSolversBase.InPlaceObjectiveFG_Hv) = TwiceDifferentiableHV(f, x)
72-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f, g) = TwiceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace, autodiff = autodiff)
73-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, f, g, h) = TwiceDifferentiable(f, g, h, x, real(zero(eltype(x))); inplace = inplace)
63+
promote_objtype(method::ZerothOrderOptimizer, x, autodiff, inplace::Bool, args...) = NonDifferentiable(args..., x, real(zero(eltype(x))))
64+
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, f) = OnceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
65+
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, args...) = OnceDifferentiable(args..., x, real(zero(eltype(x))); inplace = inplace)
66+
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, f, g, h) = OnceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace)
67+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f) = TwiceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
68+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f::NotInplaceObjective) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
69+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f::InplaceObjective) = TwiceDifferentiable(f, x, real(zero(eltype(x))))
70+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f::NLSolversBase.InPlaceObjectiveFGHv) = TwiceDifferentiableHV(f, x)
71+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f::NLSolversBase.InPlaceObjectiveFG_Hv) = TwiceDifferentiableHV(f, x)
72+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f, g) = TwiceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace, autodiff = autodiff)
73+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, f, g, h) = TwiceDifferentiable(f, g, h, x, real(zero(eltype(x))); inplace = inplace)
7474
# no-op
75-
promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool, nd::NonDifferentiable) = nd
76-
promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool, od::OnceDifferentiable) = od
77-
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, od::OnceDifferentiable) = od
78-
promote_objtype(method::ZerothOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
79-
promote_objtype(method::FirstOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
80-
promote_objtype(method::SecondOrderOptimizer, x, autodiff::Symbol, inplace::Bool, td::TwiceDifferentiable) = td
75+
promote_objtype(method::ZerothOrderOptimizer, x, autodiff, inplace::Bool, nd::NonDifferentiable) = nd
76+
promote_objtype(method::ZerothOrderOptimizer, x, autodiff, inplace::Bool, od::OnceDifferentiable) = od
77+
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, od::OnceDifferentiable) = od
78+
promote_objtype(method::ZerothOrderOptimizer, x, autodiff, inplace::Bool, td::TwiceDifferentiable) = td
79+
promote_objtype(method::FirstOrderOptimizer, x, autodiff, inplace::Bool, td::TwiceDifferentiable) = td
80+
promote_objtype(method::SecondOrderOptimizer, x, autodiff, inplace::Bool, td::TwiceDifferentiable) = td
8181

8282
# if no method or options are present
8383
function optimize(f, initial_x::AbstractArray; inplace = true, autodiff = :finite, kwargs...)

src/multivariate/solvers/constrained/ipnewton/ipnewton.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@ end
77

88
Base.summary(::IPNewton) = "Interior Point Newton"
99

10-
promote_objtype(method::IPNewton, x, autodiff::Symbol, inplace::Bool, f::TwiceDifferentiable) = f
11-
promote_objtype(method::IPNewton, x, autodiff::Symbol, inplace::Bool, f) = TwiceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
12-
promote_objtype(method::IPNewton, x, autodiff::Symbol, inplace::Bool, f, g) = TwiceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace, autodiff = autodiff)
13-
promote_objtype(method::IPNewton, x, autodiff::Symbol, inplace::Bool, f, g, h) = TwiceDifferentiable(f, g, h, x, real(zero(eltype(x))); inplace = inplace)
10+
promote_objtype(method::IPNewton, x, autodiff, inplace::Bool, f::TwiceDifferentiable) = f
11+
promote_objtype(method::IPNewton, x, autodiff, inplace::Bool, f) = TwiceDifferentiable(f, x, real(zero(eltype(x))); autodiff = autodiff)
12+
promote_objtype(method::IPNewton, x, autodiff, inplace::Bool, f, g) = TwiceDifferentiable(f, g, x, real(zero(eltype(x))); inplace = inplace, autodiff = autodiff)
13+
promote_objtype(method::IPNewton, x, autodiff, inplace::Bool, f, g, h) = TwiceDifferentiable(f, g, h, x, real(zero(eltype(x))); inplace = inplace)
1414

1515
# TODO: Add support for InitialGuess from LineSearches
1616
"""

test/general/objective_types.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,41 @@
66
for T in (OnceDifferentiable, TwiceDifferentiable)
77
odad1 = T(x->5.0, rand(1); autodiff = :finite)
88
odad2 = T(x->5.0, rand(1); autodiff = :forward)
9+
odad3 = T(x->5.0, rand(1); autodiff = AutoReverseDiff())
910
Optim.gradient!(odad1, rand(1))
1011
Optim.gradient!(odad2, rand(1))
11-
# odad3 = T(x->5., rand(1); autodiff = :reverse)
12+
Optim.gradient!(odad3, rand(1))
1213
@test Optim.gradient(odad1) == [0.0]
1314
@test Optim.gradient(odad2) == [0.0]
14-
# @test odad3.g == [0.0]
15+
@test Optim.gradient(odad3) == [0.0]
1516
end
1617

1718
for a in (1.0, 5.0)
1819
xa = rand(1)
1920
odad1 = OnceDifferentiable(x->a*x[1], xa; autodiff = :finite)
2021
odad2 = OnceDifferentiable(x->a*x[1], xa; autodiff = :forward)
21-
# odad3 = OnceDifferentiable(x->a*x[1], xa; autodiff = :reverse)
22+
odad3 = OnceDifferentiable(x->a*x[1], xa; autodiff = AutoReverseDiff())
2223
Optim.gradient!(odad1, xa)
2324
Optim.gradient!(odad2, xa)
25+
Optim.gradient!(odad3, xa)
2426
@test Optim.gradient(odad1) ≈ [a]
2527
@test Optim.gradient(odad2) == [a]
26-
# @test odad3.g == [a]
28+
@test Optim.gradient(odad3) == [a]
2729
end
2830
for a in (1.0, 5.0)
2931
xa = rand(1)
3032
odad1 = OnceDifferentiable(x->a*x[1]^2, xa; autodiff = :finite)
3133
odad2 = OnceDifferentiable(x->a*x[1]^2, xa; autodiff = :forward)
32-
# odad3 = OnceDifferentiable(x->a*x[1]^2, xa; autodiff = :reverse)
34+
odad3 = OnceDifferentiable(x->a*x[1]^2, xa; autodiff = AutoReverseDiff())
3335
Optim.gradient!(odad1, xa)
3436
Optim.gradient!(odad2, xa)
35-
@test Optim.gradient(odad1) ≈ 2.0*a*xa
37+
Optim.gradient!(odad3, xa)
38+
@test Optim.gradient(odad1) ≈ 2.0*a*xa
3639
@test Optim.gradient(odad2) == 2.0*a*xa
37-
# @test odad3.g == 2.0*a*xa
40+
@test Optim.gradient(odad3) == 2.0*a*xa
3841
end
3942
for dtype in (OnceDifferentiable, TwiceDifferentiable)
40-
for autodiff in (:finite, :forward)
43+
for autodiff in (:finite, :forward, AutoReverseDiff())
4144
differentiable = dtype(x->sum(x), rand(2); autodiff = autodiff)
4245
Optim.value(differentiable)
4346
Optim.value!(differentiable, rand(2))

test/runtests.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ import NLSolversBase: clear!
1313
import LinearAlgebra: norm, diag, I, Diagonal, dot, eigen, issymmetric, mul!
1414
import SparseArrays: normalize!, spdiagm
1515

16+
import ReverseDiff
17+
using ADTypes: AutoReverseDiff
18+
1619
debug_printing = false
1720
test_broken = false
1821

@@ -244,40 +247,41 @@ function run_optim_tests_constrained(method; convergence_exceptions = (),
244247
end
245248
end
246249

247-
248-
@testset "special" begin
249-
for my_test in special_tests
250-
println(my_test)
251-
@time include(my_test)
250+
@testset verbose=true "Optim.jl" begin
251+
@testset "special" begin
252+
@testset for my_test in special_tests
253+
println(my_test)
254+
@time include(my_test)
255+
end
252256
end
253-
end
254-
@testset "general" begin
255-
for my_test in general_tests
256-
println(my_test)
257-
@time include(my_test)
257+
@testset "general" begin
258+
@testset for my_test in general_tests
259+
println(my_test)
260+
@time include(my_test)
261+
end
258262
end
259-
end
260-
@testset "univariate" begin
261-
for my_test in univariate_tests
262-
println(my_test)
263-
@time include(my_test)
263+
@testset "univariate" begin
264+
@testset for my_test in univariate_tests
265+
println(my_test)
266+
@time include(my_test)
267+
end
264268
end
265-
end
266-
@testset "multivariate" begin
267-
for my_test in multivariate_tests
268-
println(my_test)
269-
@time include(my_test)
269+
@testset "multivariate" begin
270+
@testset for my_test in multivariate_tests
271+
println(my_test)
272+
@time include(my_test)
273+
end
270274
end
271-
end
272275

273-
println("Literate examples")
274-
@time include("examples.jl")
276+
println("Literate examples")
277+
@time include("examples.jl")
275278

276-
@testset "show method for options" begin
277-
o = Optim.Options()
278-
@test occursin(" = ", sprint(show, o))
279-
end
279+
@testset "show method for options" begin
280+
o = Optim.Options()
281+
@test occursin(" = ", sprint(show, o))
282+
end
280283

281-
@testset "MOI wrapper" begin
282-
include("MOI_wrapper.jl")
284+
@testset "MOI wrapper" begin
285+
include("MOI_wrapper.jl")
286+
end
283287
end

0 commit comments

Comments
 (0)