@@ -382,6 +382,7 @@ Base.show(io::IO, ::MIME"text/plain", e::TimedNLPModel) = Base.print(io, e);
382382struct CompressedNLPModel{
383383 T,
384384 VT<: AbstractVector{T} ,
385+ B,
385386 VI<: AbstractVector{Int} ,
386387 VI2<: AbstractVector{Tuple{Tuple{Int,Int},Int}} ,
387388 M<: NLPModels.AbstractNLPModel{T,VT} ,
@@ -394,41 +395,43 @@ struct CompressedNLPModel{
394395 hsparsity:: VI2
395396 buffer:: VT
396397
398+ backend:: B
397399 meta:: NLPModels.NLPModelMeta{T,VT}
398400 counters:: NLPModels.Counters
399401end
400402
401- function getptr (array)
403+ function getptr (backend :: Nothing , array; cmp = (x,y) -> x != y )
402404 return push! (
403405 pushfirst! (
404- findall (_is_sparsity_not_equal .(@view (array[1 : end - 1 ]), @view (array[2 : end ]))) .+ =
406+ findall (cmp .(@view (array[1 : end - 1 ]), @view (array[2 : end ]))) .+ =
405407 1 ,
406408 1 ,
407409 ),
408410 length (array) + 1 ,
409411 )
410412end
411- _is_sparsity_not_equal (a, b) = first (a) != first (b)
412413
413414function CompressedNLPModel (m)
414415
415416 nnzj = NLPModels. get_nnzj (m)
416- Ibuffer = Vector {Int} (undef , nnzj)
417- Jbuffer = Vector {Int} (undef , nnzj)
417+ Ibuffer = similar (m . meta . x0, Int , nnzj)
418+ Jbuffer = similar (m . meta . x0, Int , nnzj)
418419 NLPModels. jac_structure! (m, Ibuffer, Jbuffer)
419420
420- jsparsity = map ((k, i, j) -> ((j, i), k), 1 : nnzj, Ibuffer, Jbuffer)
421+ backend = getbackend (m)
422+
423+ jsparsity = get_compressed_sparsity (nnzj, Ibuffer, Jbuffer, backend)
421424 sort! (jsparsity; lt = (a, b) -> a[1 ] < b[1 ])
422- jptr = getptr (jsparsity)
425+ jptr = getptr (backend, jsparsity; cmp = (a, b) -> first (a) != first (b) )
423426
424427 nnzh = NLPModels. get_nnzh (m)
425428 resize! (Ibuffer, nnzh)
426429 resize! (Jbuffer, nnzh)
427430 NLPModels. hess_structure! (m, Ibuffer, Jbuffer)
428431
429- hsparsity = map ((k, i, j) -> ((j, i), k), 1 : nnzh, Ibuffer, Jbuffer)
432+ hsparsity = get_compressed_sparsity ( nnzh, Ibuffer, Jbuffer, backend )
430433 sort! (hsparsity; lt = (a, b) -> a[1 ] < b[1 ])
431- hptr = getptr (hsparsity)
434+ hptr = getptr (backend, hsparsity; cmp = (a, b) -> first (a) != first (b) )
432435
433436 buffer = similar (m. meta. x0, max (nnzj, nnzh))
434437
@@ -447,9 +450,12 @@ function CompressedNLPModel(m)
447450
448451 counters = NLPModels. Counters ()
449452
450- return CompressedNLPModel (m, jptr, jsparsity, hptr, hsparsity, buffer, meta, counters)
453+ return CompressedNLPModel (m, jptr, jsparsity, hptr, hsparsity, buffer, backend, meta, counters)
451454end
452455
456+ getbackend (m) = nothing
457+ get_compressed_sparsity (nnz, Ibuffer, Jbuffer, backend:: Nothing ) = map ((k, i, j) -> ((j, i), k), 1 : nnz, Ibuffer, Jbuffer)
458+
453459function NLPModels. obj (m:: CompressedNLPModel , x:: AbstractVector )
454460 NLPModels. obj (m. inner, x)
455461end
464470
465471function NLPModels. jac_coord! (m:: CompressedNLPModel , x:: AbstractVector , j:: AbstractVector )
466472 NLPModels. jac_coord! (m. inner, x, m. buffer)
467- _compress! (j, m. buffer, m. jptr, m. jsparsity)
473+ _compress! (j, m. buffer, m. jptr, m. jsparsity, m . backend )
468474end
469475
470476function NLPModels. hess_coord! (
@@ -475,26 +481,26 @@ function NLPModels.hess_coord!(
475481 obj_weight = 1.0 ,
476482)
477483 NLPModels. hess_coord! (m. inner, x, y, m. buffer; obj_weight = obj_weight)
478- _compress! (h, m. buffer, m. hptr, m. hsparsity)
484+ _compress! (h, m. buffer, m. hptr, m. hsparsity, m . backend )
479485end
480486
481487function NLPModels. jac_structure! (
482488 m:: CompressedNLPModel ,
483489 I:: AbstractVector ,
484490 J:: AbstractVector ,
485491)
486- _structure! (I, J, m. jptr, m. jsparsity)
492+ _structure! (I, J, m. jptr, m. jsparsity, m . backend )
487493end
488494
489495function NLPModels. hess_structure! (
490496 m:: CompressedNLPModel ,
491497 I:: AbstractVector ,
492498 J:: AbstractVector ,
493499)
494- _structure! (I, J, m. hptr, m. hsparsity)
500+ _structure! (I, J, m. hptr, m. hsparsity, m . backend )
495501end
496502
497- function _compress! (V, buffer, ptr, sparsity)
503+ function _compress! (V, buffer, ptr, sparsity, backend :: Nothing )
498504 fill! (V, zero (eltype (V)))
499505 @simd for i = 1 : length (ptr)- 1
500506 for j = ptr[i]: ptr[i+ 1 ]- 1
@@ -503,7 +509,7 @@ function _compress!(V, buffer, ptr, sparsity)
503509 end
504510end
505511
506- function _structure! (I, J, ptr, sparsity)
512+ function _structure! (I, J, ptr, sparsity, backend :: Nothing )
507513 @simd for i = 1 : length (ptr)- 1
508514 J[i], I[i] = sparsity[ptr[i]][1 ]
509515 end
0 commit comments