-
Notifications
You must be signed in to change notification settings - Fork 74
Open
Description
Full error message: https://gist.github.com/penelopeysm/b106d30daf38a536913f69aa113532cc
This one also crashes only with reverse mode on Julia 1.11. It runs without error in forward mode, on Julia 1.10, or both.
"""
    IL(idxs)

Wrapper around an index collection `idxs` (e.g. the tuple `(1,)`); the field's
concrete type is captured as the type parameter.
"""
struct IL{I}
    idxs::I
end
"""
    VN{sym}(inner)

Variable-name stand-in: `sym` is carried in the type domain and `inner` must be
either `identity` or an `IL` (enforced by the `T` bound). Callers supply only
`sym`; the second type parameter is taken from `typeof(inner)`.
"""
struct VN{sym,T<:Union{typeof(identity),IL}}
    inner::T
    VN{sym}(val) where {sym} = new{sym,typeof(val)}(val)
end
# Two `VN`s compare equal when their symbols match and their `inner` payloads are `==`.
# No `Base.hash` overload accompanies this; while `IL` keeps the default `==` (i.e. `===`),
# this definition coincides with `===`, so the default hash used by `Dict` stays consistent.
Base.:(==)(a::VN{s1}, b::VN{s2}) where {s1,s2} = (s1 == s2) && (a.inner == b.inner)
"""
    VI(metadata, logp)

Varinfo stand-in pairing a `metadata` container with a `logp` value
(presumably a log-probability accumulator — naming only, not enforced).
"""
struct VI{M,L}
    metadata::M
    logp::L
end
"""
    TSVI(varinfo)

Thin wrapper around a single `varinfo` value, whose concrete type becomes the
type parameter.
"""
struct TSVI{W}
    varinfo::W
end
"""
    MyMetadata(idcs, vals)

Metadata stand-in: `idcs` is a `Dict` from `VN` keys to `Int` values and `vals`
is a vector of real numbers. Both concrete field types are captured as type
parameters.
"""
struct MyMetadata{D<:Dict{<:VN,Int},V<:AbstractVector{<:Real}}
    idcs::D
    vals::V
end
"""
    g(vi, z)

Allocate a `Float64` buffer of length `length(z)`, fill every slot with
`vi.varinfo.metadata.s.vals[1]`, then return a new `VI` that reuses
`vi.varinfo.metadata` with `logp` set to `1.0`.

Expects `vi` to expose `vi.varinfo.metadata.s.vals` (as the `TSVI` built in `f`
does). Note that the result is a bare `VI`, not a `TSVI` like the input.

The buffer `s` is never read after the loop. It appears to be kept deliberately:
this read-into-scratch-buffer pattern is what the reproducer exercises, so do not
remove it while the crash is being investigated.
"""
function g(vi, z)
    s = Vector{Float64}(undef, length(z))
    for i in eachindex(s)
        # Reads the first entry of `vals`, which in `f` is the differentiated input `x`.
        s[i] = vi.varinfo.metadata.s.vals[1]
    end
    # `logp` is a constant, so the returned value does not depend on `vals`.
    return VI(vi.varinfo.metadata, 1.0)
end
# Single-element `Float64` input at which `f` and its gradient are evaluated.
params = zeros(Float64, 1)
"""
    f(x)

Reproducer entry point: wrap `x` in a `MyMetadata` keyed by `VN{:s}(IL((1,)))`,
nest it as `TSVI(VI((; s = md_s), 0.0))`, pass it through `g`, and return the
resulting `logp`. Since `g` always sets `logp` to `1.0`, the return value does
not depend on `x`.
"""
function f(x)
    md_s = MyMetadata(
        # Abstract key type `VN` (not a concrete `VN{:s,...}`), allowed by the
        # `Dict{<:VN,Int}` bound on `MyMetadata`.
        Dict{VN,Int}(VN{:s}(IL((1,))) => 1),
        x,
    )
    vi = TSVI(VI((; s=md_s), 0.0))
    # `g` returns a `VI`, so `vi` changes type here; only `.logp` is read afterwards.
    vi = g(vi, [1.5, 2.0])
    return vi.logp
end
# Plain (primal) evaluation of `f`; expected to succeed.
f(params)
# The import must precede the gradient call below.
import Enzyme: Enzyme, Reverse, set_runtime_activity
# Reverse-mode gradient with runtime activity enabled. Per the report, this is the
# call that crashes on Julia 1.11; forward mode, or Julia 1.10, does not crash.
Enzyme.gradient(set_runtime_activity(Reverse), f, params)
Metadata
Metadata
Assignees
Labels
No labels