Skip to content

Commit 9c343b5

Browse files
authored
compressed gpu (#88)
1 parent eadae25 commit 9c343b5

File tree

2 files changed

+67
-26
lines changed

2 files changed

+67
-26
lines changed

ext/ExaModelsKernelAbstractions.jl

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ ExaModels.ExaCore(T, backend::KernelAbstractions.CPU) =
1414
ExaModels.ExaCore(x0 = zeros(T, 0), backend = backend)
1515
ExaModels.ExaCore(backend::KernelAbstractions.CPU) = ExaModels.ExaCore(backend = backend)
1616

17-
function getptr(backend, array; cmp = isequal)
17+
function ExaModels.getptr(backend, array; cmp = (x,y) -> x != y)
1818

1919
bitarray = similar(array, Bool, length(array) + 1)
2020
kergetptr(backend)(cmp, bitarray, array; ndrange = length(array) + 1)
@@ -48,14 +48,14 @@ function ExaModels.ExaModel(
4848
if !isempty(gsparsity)
4949
ExaModels.sort!(gsparsity; lt = ((i, j), (k, l)) -> i < k)
5050
end
51-
gptr = getptr(c.backend, gsparsity)
51+
gptr = ExaModels.getptr(c.backend, gsparsity; cmp = (x,y) -> x[1] != y[1])
5252

5353
conaugsparsity = similar(c.x0, Tuple{Int,Int}, c.nconaug)
5454
_conaug_structure!(c.backend, c.con, conaugsparsity)
5555
if !isempty(conaugsparsity)
5656
ExaModels.sort!(conaugsparsity; lt = ((i, j), (k, l)) -> i < k)
5757
end
58-
conaugptr = getptr(c.backend, conaugsparsity)
58+
conaugptr = ExaModels.getptr(c.backend, conaugsparsity; cmp = (x,y) -> x[1] != y[1])
5959

6060

6161
if prod
@@ -73,22 +73,22 @@ function ExaModels.ExaModel(
7373
if !isempty(jacsparsityi)
7474
ExaModels.sort!(jacsparsityi; lt = (((i, j), k), ((n, m), l)) -> i < n)
7575
end
76-
jacptri = getptr(c.backend, jacsparsityi; cmp = (x, y) -> x[1] == y[1])
76+
jacptri = ExaModels.getptr(c.backend, jacsparsityi; cmp = (x, y) -> x[1][1] != y[1][1])
7777

7878
if !isempty(jacsparsityj)
7979
ExaModels.sort!(jacsparsityj; lt = (((i, j), k), ((n, m), l)) -> j < m)
8080
end
81-
jacptrj = getptr(c.backend, jacsparsityj; cmp = (x, y) -> x[2] == y[2])
81+
jacptrj = ExaModels.getptr(c.backend, jacsparsityj; cmp = (x, y) -> x[1][2] != y[1][2])
8282

8383

8484
if !isempty(hesssparsityi)
8585
ExaModels.sort!(hesssparsityi; lt = (((i, j), k), ((n, m), l)) -> i < n)
8686
end
87-
hessptri = getptr(c.backend, hesssparsityi; cmp = (x, y) -> x[1] == y[1])
87+
hessptri = ExaModels.getptr(c.backend, hesssparsityi; cmp = (x, y) -> x[1][1] != y[1][1])
8888
if !isempty(hesssparsityj)
8989
ExaModels.sort!(hesssparsityj; lt = (((i, j), k), ((n, m), l)) -> j < m)
9090
end
91-
hessptrj = getptr(c.backend, hesssparsityj; cmp = (x, y) -> x[2] == y[2])
91+
hessptrj = ExaModels.getptr(c.backend, hesssparsityj; cmp = (x, y) -> x[1][2] != y[1][2])
9292

9393
prodhelper = (
9494
jacbuffer = jacbuffer,
@@ -620,15 +620,50 @@ end
620620
elseif I == length(array) + 1
621621
bitarray[I] = true
622622
else
623-
i0, j0 = array[I-1]
624-
i1, j1 = array[I]
623+
i0 = array[I-1]
624+
i1 = array[I]
625625

626-
if !cmp(i0, i1)
626+
if cmp(i0, i1)
627627
bitarray[I] = true
628628
else
629629
bitarray[I] = false
630630
end
631631
end
632632
end
633633

634+
ExaModels.getbackend(m::ExaModels.ExaModel{T,VT,E}) where {T,VT,E<:KAExtension} = m.ext.backend
635+
function ExaModels._compress!(V, buffer, ptr, sparsity, backend)
636+
fill!(V, zero(eltype(V)))
637+
ker_compress!(backend)(V, buffer, ptr, sparsity; ndrange= length(ptr)-1)
638+
synchronize(backend)
639+
end
640+
641+
@kernel function ker_compress!(V, @Const(buffer), @Const(ptr), @Const(sparsity))
642+
i = @index(Global)
643+
@inbounds for j = ptr[i]:ptr[i+1]-1
644+
V[i] += buffer[sparsity[j][2]]
645+
end
646+
end
647+
648+
function ExaModels._structure!(I, J, ptr, sparsity, backend)
649+
ker_structure!(backend)(I, J, ptr, sparsity, ndrange = length(ptr)-1)
650+
synchronize(backend)
651+
end
652+
653+
@kernel function ker_structure!(I, J, @Const(ptr), @Const(sparsity))
654+
i = @index(Global)
655+
@inbounds J[i], I[i] = sparsity[ptr[i]][1]
656+
end
657+
658+
function ExaModels.get_compressed_sparsity(nnz, Ibuffer, Jbuffer, backend)
659+
sparsity = similar(Ibuffer, Tuple{Tuple{Int,Int},Int},nnz)
660+
ker_get_compressed_sparsity(backend)(sparsity, Ibuffer, Jbuffer; ndrange = nnz)
661+
synchronize(backend)
662+
return sparsity
663+
end
664+
@kernel function ker_get_compressed_sparsity(sparsity, @Const(I), @Const(J))
665+
i = @index(Global)
666+
@inbounds sparsity[i] = ((J[i],I[i]), i)
667+
end
668+
634669
end # module ExaModelsKernelAbstractions

src/utils.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,7 @@ Base.show(io::IO, ::MIME"text/plain", e::TimedNLPModel) = Base.print(io, e);
382382
struct CompressedNLPModel{
383383
T,
384384
VT<:AbstractVector{T},
385+
B,
385386
VI<:AbstractVector{Int},
386387
VI2<:AbstractVector{Tuple{Tuple{Int,Int},Int}},
387388
M<:NLPModels.AbstractNLPModel{T,VT},
@@ -394,41 +395,43 @@ struct CompressedNLPModel{
394395
hsparsity::VI2
395396
buffer::VT
396397

398+
backend::B
397399
meta::NLPModels.NLPModelMeta{T,VT}
398400
counters::NLPModels.Counters
399401
end
400402

401-
function getptr(array)
403+
function getptr(backend::Nothing, array; cmp = (x,y) -> x != y)
402404
return push!(
403405
pushfirst!(
404-
findall(_is_sparsity_not_equal.(@view(array[1:end-1]), @view(array[2:end]))) .+=
406+
findall(cmp.(@view(array[1:end-1]), @view(array[2:end]))) .+=
405407
1,
406408
1,
407409
),
408410
length(array) + 1,
409411
)
410412
end
411-
_is_sparsity_not_equal(a, b) = first(a) != first(b)
412413

413414
function CompressedNLPModel(m)
414415

415416
nnzj = NLPModels.get_nnzj(m)
416-
Ibuffer = Vector{Int}(undef, nnzj)
417-
Jbuffer = Vector{Int}(undef, nnzj)
417+
Ibuffer = similar(m.meta.x0, Int, nnzj)
418+
Jbuffer = similar(m.meta.x0, Int, nnzj)
418419
NLPModels.jac_structure!(m, Ibuffer, Jbuffer)
419420

420-
jsparsity = map((k, i, j) -> ((j, i), k), 1:nnzj, Ibuffer, Jbuffer)
421+
backend = getbackend(m)
422+
423+
jsparsity = get_compressed_sparsity(nnzj, Ibuffer, Jbuffer, backend)
421424
sort!(jsparsity; lt = (a, b) -> a[1] < b[1])
422-
jptr = getptr(jsparsity)
425+
jptr = getptr(backend, jsparsity; cmp = (a, b) -> first(a) != first(b))
423426

424427
nnzh = NLPModels.get_nnzh(m)
425428
resize!(Ibuffer, nnzh)
426429
resize!(Jbuffer, nnzh)
427430
NLPModels.hess_structure!(m, Ibuffer, Jbuffer)
428431

429-
hsparsity = map((k, i, j) -> ((j, i), k), 1:nnzh, Ibuffer, Jbuffer)
432+
hsparsity = get_compressed_sparsity(nnzh, Ibuffer, Jbuffer, backend)
430433
sort!(hsparsity; lt = (a, b) -> a[1] < b[1])
431-
hptr = getptr(hsparsity)
434+
hptr = getptr(backend, hsparsity; cmp = (a, b) -> first(a) != first(b))
432435

433436
buffer = similar(m.meta.x0, max(nnzj, nnzh))
434437

@@ -447,9 +450,12 @@ function CompressedNLPModel(m)
447450

448451
counters = NLPModels.Counters()
449452

450-
return CompressedNLPModel(m, jptr, jsparsity, hptr, hsparsity, buffer, meta, counters)
453+
return CompressedNLPModel(m, jptr, jsparsity, hptr, hsparsity, buffer, backend, meta, counters)
451454
end
452455

456+
getbackend(m) = nothing
457+
get_compressed_sparsity(nnz, Ibuffer, Jbuffer, backend::Nothing) = map((k, i, j) -> ((j, i), k), 1:nnz, Ibuffer, Jbuffer)
458+
453459
function NLPModels.obj(m::CompressedNLPModel, x::AbstractVector)
454460
NLPModels.obj(m.inner, x)
455461
end
@@ -464,7 +470,7 @@ end
464470

465471
function NLPModels.jac_coord!(m::CompressedNLPModel, x::AbstractVector, j::AbstractVector)
466472
NLPModels.jac_coord!(m.inner, x, m.buffer)
467-
_compress!(j, m.buffer, m.jptr, m.jsparsity)
473+
_compress!(j, m.buffer, m.jptr, m.jsparsity, m.backend)
468474
end
469475

470476
function NLPModels.hess_coord!(
@@ -475,26 +481,26 @@ function NLPModels.hess_coord!(
475481
obj_weight = 1.0,
476482
)
477483
NLPModels.hess_coord!(m.inner, x, y, m.buffer; obj_weight = obj_weight)
478-
_compress!(h, m.buffer, m.hptr, m.hsparsity)
484+
_compress!(h, m.buffer, m.hptr, m.hsparsity, m.backend)
479485
end
480486

481487
function NLPModels.jac_structure!(
482488
m::CompressedNLPModel,
483489
I::AbstractVector,
484490
J::AbstractVector,
485491
)
486-
_structure!(I, J, m.jptr, m.jsparsity)
492+
_structure!(I, J, m.jptr, m.jsparsity, m.backend)
487493
end
488494

489495
function NLPModels.hess_structure!(
490496
m::CompressedNLPModel,
491497
I::AbstractVector,
492498
J::AbstractVector,
493499
)
494-
_structure!(I, J, m.hptr, m.hsparsity)
500+
_structure!(I, J, m.hptr, m.hsparsity, m.backend)
495501
end
496502

497-
function _compress!(V, buffer, ptr, sparsity)
503+
function _compress!(V, buffer, ptr, sparsity, backend::Nothing)
498504
fill!(V, zero(eltype(V)))
499505
@simd for i = 1:length(ptr)-1
500506
for j = ptr[i]:ptr[i+1]-1
@@ -503,7 +509,7 @@ function _compress!(V, buffer, ptr, sparsity)
503509
end
504510
end
505511

506-
function _structure!(I, J, ptr, sparsity)
512+
function _structure!(I, J, ptr, sparsity, backend::Nothing)
507513
@simd for i = 1:length(ptr)-1
508514
J[i], I[i] = sparsity[ptr[i]][1]
509515
end

0 commit comments

Comments
 (0)