Skip to content

Commit 0469f92

Browse files
authored
GPU Implementation of hprod! for unconstrained (#140)
1 parent c8bb0e7 commit 0469f92

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

ext/ExaModelsKernelAbstractions.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,44 @@ function ExaModels.hprod!(
419419

420420
return Hv
421421
end
422+
function ExaModels.hprod!(
423+
m::ExaModels.ExaModel{T,VT,E},
424+
x::AbstractVector,
425+
v::AbstractVector,
426+
Hv::AbstractVector;
427+
obj_weight = one(eltype(x)),
428+
) where {T,VT,N<:NamedTuple,E<:KAExtension{T,VT,N}}
429+
430+
if isnothing(m.ext.prodhelper)
431+
error("Prodhelper is not defined. Use ExaModels(c; prod=true) to use hprod!")
432+
end
433+
434+
fill!(Hv, zero(eltype(Hv)))
435+
fill!(m.ext.prodhelper.hessbuffer, zero(eltype(Hv)))
436+
437+
_obj_hess_coord!(m.ext.backend, m.ext.prodhelper.hessbuffer, m.objs, x, obj_weight)
438+
synchronize(m.ext.backend)
439+
kersyspmv(m.ext.backend)(
440+
Hv,
441+
v,
442+
m.ext.prodhelper.hesssparsityi,
443+
m.ext.prodhelper.hessbuffer,
444+
m.ext.prodhelper.hessptri,
445+
ndrange = length(m.ext.prodhelper.hessptri) - 1,
446+
)
447+
synchronize(m.ext.backend)
448+
kersyspmv2(m.ext.backend)(
449+
Hv,
450+
v,
451+
m.ext.prodhelper.hesssparsityj,
452+
m.ext.prodhelper.hessbuffer,
453+
m.ext.prodhelper.hessptrj,
454+
ndrange = length(m.ext.prodhelper.hessptrj) - 1,
455+
)
456+
synchronize(m.ext.backend)
457+
458+
return Hv
459+
end
422460

423461
@kernel function kerspmv(y, @Const(x), @Const(coord), @Const(V), @Const(ptr))
424462
idx = @index(Global)

0 commit comments

Comments
 (0)