Skip to content

test_rrule fails on the pedagogical example #593

@mcognetta

Description

@mcognetta

When I follow the pedagogical example exactly and test it with `test_rrule`, the test fails with a type-inference error. After I fix the type instability, a different error appears. Here is a demonstration:

julia> using ChainRules, ChainRulesCore, ChainRulesTestUtils

julia> struct Foo
           A::Matrix
           c::Float64
       end

julia> function foo_mul(foo::Foo, b::AbstractArray)
           return foo.A * b
       end
foo_mul (generic function with 1 method)

julia> @code_warntype foo_mul(Foo(rand(3,3), 3.0), rand(3,3))
MethodInstance for foo_mul(::Foo, ::Matrix{Float64})
  from foo_mul(foo::Foo, b::AbstractArray) in Main at REPL[3]:1
Arguments
  #self#::Core.Const(foo_mul)
  foo::Foo
  b::Matrix{Float64}
Body::Any
1 ─ %1 = Base.getproperty(foo, :A)::Matrix
│   %2 = (%1 * b)::Any
└──      return %2

julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
           y = foo_mul(foo, b)
           function foo_mul_pullback(ȳ)
               f̄ = NoTangent()
               f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
               b̄ = @thunk(foo.A' * ȳ)
               return f̄, f̄oo, b̄
           end
           return y, foo_mul_pullback
       end

julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
test_rrule: foo_mul on Foo,Matrix{Float64}: Error During Test at /home/marco/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  return type Matrix{Float64} does not match inferred return type Any
  Stacktrace:
    [1] error(s::String)
      @ Base ./error.jl:33
    [2] _test_inferred(f::Any, args::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:255
    [3] _test_inferred
   
   ########################
   # I CUT MOST OF THIS OUT
   ########################

   [25] _start()
      @ Base ./client.jl:495
Test Summary:                              | Pass  Error  Total
test_rrule: foo_mul on Foo,Matrix{Float64} |    7      1      8
ERROR: Some tests did not pass: 7 passed, 0 failed, 1 errored, 0 broken.

After fixing the type instability:

julia> struct Foo{T}
           A::Matrix{T}
           c::Float64
       end

julia> function foo_mul(foo::Foo, b::AbstractArray)
           return foo.A * b
       end
foo_mul (generic function with 1 method)

julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo, b::AbstractArray)
           y = foo_mul(foo, b)
           function foo_mul_pullback(ȳ)
               f̄ = NoTangent()
               f̄oo = Tangent{Foo}(; A=ȳ * b', c=ZeroTangent())
               b̄ = @thunk(foo.A' * ȳ)
               return f̄, f̄oo, b̄
           end
           return y, foo_mul_pullback
       end

julia> @code_warntype foo_mul(Foo(rand(3,3), 3.0), rand(3,3))
MethodInstance for foo_mul(::Foo{Float64}, ::Matrix{Float64})
  from foo_mul(foo::Foo, b::AbstractArray) in Main at REPL[3]:1
Arguments
  #self#::Core.Const(foo_mul)
  foo::Foo{Float64}
  b::Matrix{Float64}
Body::Matrix{Float64}
1 ─ %1 = Base.getproperty(foo, :A)::Matrix{Float64}
│   %2 = (%1 * b)::Matrix{Float64}
└──      return %2


julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
test_rrule: foo_mul on Foo{Float64},Matrix{Float64}: Error During Test at /home/marco/.julia/packages/ChainRulesTestUtils/YbVdW/src/testers.jl:193
  Got exception outside of a @test
  MethodError: no method matching +(::Tangent{Foo{Float64}, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, Float64}}}, ::Tangent{Foo, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, ZeroTangent}}})
  Closest candidates are:
    +(::Any, ::Any, ::Any, ::Any...) at ~/github/julia/usr/share/julia/base/operators.jl:655
    +(::Dict, ::Tangent{P}) where P at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:145
    +(::AbstractThunk, ::Tangent) at ~/.julia/packages/ChainRulesCore/C73ay/src/tangent_arithmetic.jl:122
    ...
  Stacktrace:
    [1] add!!(x::Tangent{Foo{Float64}, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, Float64}}}, y::Tangent{Foo, NamedTuple{(:A, :c), Tuple{Matrix{Float64}, ZeroTangent}}})
      @ ChainRulesCore ~/.julia/packages/ChainRulesCore/C73ay/src/accumulation.jl:7
    [2] _test_add!!_behaviour(acc::Any, val::Any; kwargs::Base.Pairs{Symbol, V, Tuple{Vararg{Symbol, N}}, NamedTuple{names, T}} where {V, N, names, T<:Tuple{Vararg{Any, N}}})
      @ ChainRulesTestUtils ~/.julia/packages/ChainRulesTestUtils/YbVdW/src/check_result.jl:188

   ########################
   # I CUT MOST OF THIS OUT
   ########################

   [24] _start()
      @ Base ./client.jl:495
Test Summary:                                       | Pass  Error  Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} |    6      1      7
ERROR: Some tests did not pass: 6 passed, 0 failed, 1 errored, 0 broken.

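I believe the fix is to make the rrule's signature aware of the type parameter `T` and to construct the tangent as `Tangent{Foo{T}}` (demonstrated below). The `MethodError` above arises because the rrule builds a `Tangent{Foo}`, using the unparameterized type `Foo`, while the tangent generated by `test_rrule` is a `Tangent{Foo{Float64}}`. There is no `+` method for adding those two tangent types. I am still a ChainRules novice (certainly there is a better word for the feeling of "I am but a paper clip on the chain rules ladder"), so I would like to verify this here before opening a PR.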

julia> using ChainRules, ChainRulesCore, ChainRulesTestUtils

julia> struct Foo{T}
           A::Matrix{T}
           c::Float64
       end

julia> function foo_mul(foo::Foo, b::AbstractArray)
           return foo.A * b
       end
foo_mul (generic function with 1 method)

julia> function ChainRulesCore.rrule(::typeof(foo_mul), foo::Foo{T}, b::AbstractArray) where T
           y = foo_mul(foo, b)
           function foo_mul_pullback(ȳ)
               f̄ = NoTangent()
               f̄oo = Tangent{Foo{T}}(; A=ȳ * b', c=ZeroTangent())
               b̄ = @thunk(foo.A' * ȳ)
               return f̄, f̄oo, b̄
           end
           return y, foo_mul_pullback
       end

julia> test_rrule(foo_mul, Foo(rand(3,3), 3.0), rand(3,3))
Test Summary:                                       | Pass  Total
test_rrule: foo_mul on Foo{Float64},Matrix{Float64} |   10     10
Test.DefaultTestSet("test_rrule: foo_mul on Foo{Float64},Matrix{Float64}", Any[], 10, false, false)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions