Description
I'd really like to be able to differentiate complex functions with ForwardDiff. For almost all analytic functions, the derivative can be taken in the same way as in the real case, so Julia's dynamic type system should make complex differentiation work out of the box. It seems to me that the only reason it doesn't is that many functions in ForwardDiff accept `Real` arguments; if we simply changed these to accept `Number`s, we should be most of the way there. Functions that are not analytic, such as `real`, `imag`, `abs`, `conj`, `abs2`, and `reim`, should all throw errors when their derivatives are taken at a complex input. For example, consider these two functions:
```julia
f(z) = z^2
g(z) = Complex(real(z)^2 - imag(z)^2, 2 * real(z) * imag(z))
```
These functions are equal and both analytic, but it is not obvious by looking at `g` that it is analytic. In my opinion, it would be acceptable if `derivative(f, im)` worked while `derivative(g, im)` threw an error.
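To illustrate why loosening `Real` to `Number` could be enough, here is a minimal, self-contained sketch that does not use ForwardDiff at all. It defines a toy forward-mode dual number, `CDual` (a hypothetical name, unrelated to ForwardDiff's `Dual`), whose value and derivative may be any `Number`, including `Complex`. The differentiation rules are exactly the real-case ones, and the non-analytic functions listed above are made to throw, so `f` can be differentiated while `g` errors. The helpers `cderivative` and `nonanalytic` are also hypothetical. As a simplification, this toy rejects `real`, `imag`, etc. even at real inputs, which the proposal would not.

```julia
# Toy forward-mode dual number whose value and derivative may be any Number,
# including Complex. Hypothetical illustration; not ForwardDiff's Dual type.
struct CDual{T<:Number} <: Number
    val::T  # primal value
    der::T  # derivative with respect to the input
end
CDual(v::Number, d::Number) = CDual(promote(v, d)...)

# The differentiation rules are exactly the ones used for real inputs.
Base.:+(a::CDual, b::CDual) = CDual(a.val + b.val, a.der + b.der)
Base.:-(a::CDual, b::CDual) = CDual(a.val - b.val, a.der - b.der)
Base.:*(a::CDual, b::CDual) = CDual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:*(c::Number, a::CDual) = CDual(c * a.val, c * a.der)  # c is a constant
Base.:*(a::CDual, c::Number) = CDual(a.val * c, a.der * c)
Base.:^(a::CDual, n::Integer) =
    n == 0 ? CDual(one(a.val), zero(a.der)) : CDual(a.val^n, n * a.val^(n - 1) * a.der)
Base.sin(a::CDual) = CDual(sin(a.val), cos(a.val) * a.der)
Base.cos(a::CDual) = CDual(cos(a.val), -sin(a.val) * a.der)
Base.exp(a::CDual) = CDual(exp(a.val), exp(a.val) * a.der)

# Non-analytic functions have no complex derivative, so they throw.
# (Simplification: this toy throws even for real inputs.)
nonanalytic(name) = throw(ArgumentError("$name is not complex-differentiable"))
for op in (:real, :imag, :conj, :abs, :abs2, :reim)
    @eval Base.$op(::CDual) = nonanalytic($(QuoteNode(op)))
end

# Hypothetical helper: seed the derivative with 1 and read it back.
function cderivative(f, z::Number)
    zf = float(z)
    return f(CDual(zf, one(zf))).der
end

f(z) = z^2
g(z) = Complex(real(z)^2 - imag(z)^2, 2 * real(z) * imag(z))
cderivative(f, im)      # approximately 2im
cderivative(sin, im)    # approximately cos(im)
# cderivative(g, im)    # would throw ArgumentError("real is not complex-differentiable")
```

Nothing in the rules above is specific to complex numbers; only the type bound changed, which is the same bet the proposal makes about ForwardDiff's own methods.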
I already started experimenting with this. I copied the definitions of several functions in ForwardDiff and just replaced `Real` with `Complex`, like so:
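```julia
using ForwardDiff
# Internal (non-exported) ForwardDiff names used below.
using ForwardDiff: Tag, Dual, Partials, extract_derivative, scale_tuple, div_tuple_by_scalar

# Allow Complex numbers as the value type of a Dual.
ForwardDiff.can_dual(::Type{Complex{T}}) where {T<:Real} = true

# Copy of the Real method of ForwardDiff.derivative, accepting a Complex input.
@inline function ForwardDiff.derivative(f::F, x::R) where {F,R<:Complex}
    T = typeof(Tag(f, R))
    return extract_derivative(T, f(Dual{T}(x, one(x))))
end

# Scaling partials by a Complex scalar, mirroring the existing Real methods.
@inline function Base.:*(partials::Partials, x::Complex)
    return Partials(scale_tuple(partials.values, x))
end
@inline function Base.:/(partials::Partials, x::Complex)
    return Partials(div_tuple_by_scalar(partials.values, x))
end
@inline Base.:*(x::Complex, partials::Partials) = partials * x

# Zero-length partials: only the element type needs promoting.
@inline Base.:*(partials::Partials{0,V}, x::Complex) where {V} = Partials{0,promote_type(V, typeof(x))}(tuple())
@inline Base.:*(x::Complex, partials::Partials{0,V}) where {V} = Partials{0,promote_type(V, typeof(x))}(tuple())
@inline Base.:/(partials::Partials{0,V}, x::Complex) where {V} = Partials{0,promote_type(V, typeof(x))}(tuple())
```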
This was enough for me to correctly differentiate `sin`. I expect that if we just do this throughout ForwardDiff (replacing the `Real` definitions with `Number` ones, instead of adding `Complex` definitions as I did here), we'll be most of the way there. Here's my plan:
- Find and replace `Real` with `Number` throughout the repository
- Tweak things until all the test cases pass
- Add test cases for complex functions
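For the last step of the plan, the complex test cases might compare derivatives against known closed forms at a few points. Below is a minimal sketch using Julia's `Test` standard library; `ANALYTIC_CASES`, `POINTS`, `test_complex_derivatives`, and `fd_derivative` are hypothetical names. Because ForwardDiff does not accept complex inputs yet, the derivative under test is passed in as an argument and exercised here with a central finite difference along the real axis (valid for analytic functions). Once the proposed change lands, `ForwardDiff.derivative` could presumably be passed in its place.

```julia
using Test

# Analytic functions paired with their known complex derivatives.
const ANALYTIC_CASES = [
    (z -> z^2,     z -> 2z),
    (z -> z^3 - z, z -> 3z^2 - 1),
    (sin,          cos),
    (exp,          exp),
    (z -> inv(z),  z -> -inv(z)^2),
]

# Evaluation points, kept away from the pole of inv at 0.
const POINTS = [1.0 + 0.0im, 0.5 + 0.5im, -1.2 + 2.0im, im]

# `deriv(f, z)` is the implementation under test (hypothetical parameter).
function test_complex_derivatives(deriv; rtol = 1e-6)
    @testset "complex derivatives" begin
        for (f, df) in ANALYTIC_CASES, z in POINTS
            @test isapprox(deriv(f, z), df(z); rtol = rtol)
        end
    end
end

# Stand-in implementation: central difference along the real axis.
# For analytic f this approaches f'(z); hypothetical, for illustration only.
fd_derivative(f, z; h = 1e-6) = (f(z + h) - f(z - h)) / (2h)

test_complex_derivatives(fd_derivative)
```

Keeping the derivative as a parameter lets the same table of cases be reused unchanged once complex support exists.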
I'm tempted to actually do this and make a pull request as soon as I have some time to do so.