-
Notifications
You must be signed in to change notification settings - Fork 65
Closed
Description
Following the pedagogical example exactly and testing it with `test_rrule` fails the type-inference check (the return type `Matrix{Float64}` does not match the inferred return type `Any`). When I fix that type instability, a different error appears. Here is a demonstration:
julia> using ChainRules, ChainRulesCore, ChainRulesTestUtils
julia> struct Foo
A::Matrix
c::Float64
end
julia> function foo_mul(foo::Foo, b::AbstractArray)
return foo.A * b
end
foo_mul (generic function with 1 method)
julia> @code_warntype foo_mul(Foo(rand(3,3), 3.0), rand(3,3))
MethodInstance for foo_mul(::Foo, ::Matrix{Float64})
from foo_mul(foo::Foo, b::AbstractArray) in Main at REPL[3]:1
Arguments
#self#::Core.Const(foo_mul)
foo::Foo
b::Matrix{Float64}
Body::Any
1 ─ %1 = Base.getproperty(foo, :A)::Matrix
│ %2 = (%1 * b)::Any
└── return %2
julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
y = foo_mul(foo, b)
function foo_mul_pullback(ȳ)
f̄ = NoTangent()
f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
b̄ = @thunk(foo.A' * ȳ)
return f̄, f̄oo, b̄
end
return y, foo_mul_pullback
end
julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
test_rrule: foo_mul on Foo,Matrix{Float64}: Error During Test at /home/marco/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
return type Matrix{Float64} does not match inferred return type Any
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] _test_inferred(f::Any, args::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:255
[3] _test_inferred
########################
# I CUT MOST OF THIS OUT
########################
[25] _start()
@ Base ./client.jl:495
Test Summary: | Pass Error Total
test_rrule: foo_mul on Foo,Matrix{Float64} | 7 1 8
ERROR: Some tests did not pass: 7 passed, 0 failed, 1 errored, 0 broken.
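After fixing the type instability, by giving `Foo` a type parameter `T` so that the field `A` is the concrete type `Matrix{T}`, inference succeeds but `test_rrule` raises a `MethodError` instead: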
julia> struct Foo{T}
A::Matrix{T}
c::Float64
end
julia> function foo_mul(foo::Foo, b::AbstractArray)
return foo.A * b
end
foo_mul (generic function with 1 method)
julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
y = foo_mul(foo, b)
function foo_mul_pullback(ȳ)
f̄ = NoTangent()
f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
b̄ = @thunk(foo.A' * ȳ)
return f̄, f̄oo, b̄
end
return y, foo_mul_pullback
end
julia> @code_warntype foo_mul(Foo(rand(3,3), 3.0), rand(3,3))
MethodInstance for foo_mul(::Foo{Float64}, ::Matrix{Float64})
from foo_mul(foo::Foo, b::AbstractArray) in Main at REPL[3]:1
Arguments
#self#::Core.Const(foo_mul)
foo::Foo{Float64}
b::Matrix{Float64}
Body::Matrix{Float64}
1 ─ %1 = Base.getproperty(foo, :A)::Matrix{Float64}
│ %2 = (%1 * b)::Matrix{Float64}
└── return %2
julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
test_rrule: foo_mul on Foo{Float64},Matrix{Float64}: Error During Test at /home/marco/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
Got exception outside of a @test
MethodError: no method matching +(::Tangent{Foo{Float64}, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, Float64}}}, ::Tangent{Foo, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, ZeroTangent}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at ~/github/julia/usr/share/julia/base/operators.jl:655
+(::Dict, ::Tangent{P}) where P at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:145
+(::AbstractThunk, ::Tangent) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
...
Stacktrace:
[1] add!!(x::Tangent{Foo{Float64}, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, Float64}}}, y::Tangent{Foo, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, ZeroTangent}}})
@ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/accumulation.jl:7
[2] _test_add!!_behaviour(acc::Any, val::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
@ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:188
########################
# I CUT MOST OF THIS OUT
########################
[24] _start()
@ Base ./client.jl:495
Test Summary: | Pass Error Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 6 1 7
ERROR: Some tests did not pass: 6 passed, 0 failed, 1 errored, 0 broken.
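I believe the fix is to make the `rrule` signature aware of `T` and construct the tangent as `Tangent{Foo{T}}` instead of `Tangent{Foo}` (demonstrated below). The `MethodError` above arises because `test_rrule` tries to add a `Tangent{Foo{Float64}}` to the `Tangent{Foo}` returned by the pullback. There is no `+` method combining tangents whose primal types differ. I am still a ChainRules novice, though (certainly there is a better word for the feeling of "I am but a paper clip on the chain rules ladder"), so I will verify it here before opening a PR.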
julia> using ChainRules, ChainRulesCore, ChainRulesTestUtils
julia> struct Foo{T}
A::Matrix{T}
c::Float64
end
julia> function foo_mul(foo::Foo, b::AbstractArray)
return foo.A * b
end
foo_mul (generic function with 1 method)
julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
y = foo_mul(foo, b)
function foo_mul_pullback(ȳ)
f̄ = NoTangent()
f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
b̄ = @thunk(foo.A' * ȳ)
return f̄, f̄oo, b̄
end
return y, foo_mul_pullback
end
julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
Test Summary: | Pass Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} | 10 10
Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)
Metadata
Metadata
Assignees
Labels
No labels