Skip to content

Commit c0807e5

Browse files
committed
fixed oneapi and amdgpu issues
1 parent 684dad4 commit c0807e5

File tree

4 files changed

14 additions & 37 deletions (lines changed)

4 files changed

14 additions & 37 deletions (lines changed)

ext/ExaModelsAMDGPU.jl

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -4,16 +4,4 @@ import ExaModels, AMDGPU
44

55
ExaModels.convert_array(v, backend::AMDGPU.ROCBackend) = AMDGPU.ROCArray(v)
66

7-
ExaModels.sort!(array::A; lt = isless) where {A<:AMDGPU.ROCVector} =
8-
copyto!(array, sort!(Array(array); lt = lt))
9-
10-
# Below are type piracy
11-
function Base.findall(f::F, bitarray::A) where {F<:Function,A<:AMDGPU.ROCVector}
12-
a = Array(bitarray)
13-
b = findall(f, a)
14-
c = similar(bitarray, eltype(b), length(b))
15-
16-
return copyto!(c, b)
17-
end
18-
Base.findall(bitarray::A) where {A<:AMDGPU.ROCVector} = Base.findall(identity, bitarray)
197
end

ext/ExaModelsCUDA.jl

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,6 @@ module ExaModelsCUDA
33
import ExaModels: ExaModels, NLPModels
44
import CUDA: CUDA, CUDABackend, CuArray
55

6-
ExaModels.ExaCore(backend::CUDABackend) = ExaModels.ExaCore(Float64, backend)
76
ExaModels.convert_array(v, backend::CUDABackend) = CuArray(v)
87

98
end

ext/ExaModelsOneAPI.jl

Lines changed: 5 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -7,15 +7,15 @@ function ExaModels.append!(
77
a::A,
88
b::Base.Generator{UnitRange{I}},
99
lb,
10-
) where {I,A<:oneAPI.oneVector}
10+
) where {I,A<:oneAPI.oneArray}
1111
la = length(a)
1212
aa = similar(a, la + lb)
1313
copyto!(view(aa, 1:la), a)
1414
map!(b.f, view(aa, (la+1):(la+lb)), b.iter)
1515
return aa
1616
end
1717

18-
function ExaModels.append!(backend, a::A, b::Base.Generator, lb) where {A<:oneAPI.oneVector}
18+
function ExaModels.append!(backend, a::A, b::Base.Generator, lb) where {A<:oneAPI.oneArray}
1919
la = length(a)
2020
aa = similar(a, la + lb)
2121
copyto!(view(aa, 1:la), a)
@@ -29,7 +29,7 @@ function ExaModels.append!(
2929
a::A,
3030
b::V,
3131
lb,
32-
) where {A<:oneAPI.oneVector,V<:AbstractVector}
32+
) where {A<:oneAPI.oneArray,V<:AbstractArray}
3333
la = length(a)
3434
aa = similar(a, la + lb)
3535
copyto!(view(aa, 1:la), a)
@@ -38,7 +38,7 @@ function ExaModels.append!(
3838
end
3939

4040

41-
function ExaModels.append!(backend, a::A, b::Number, lb) where {A<:oneAPI.oneVector}
41+
function ExaModels.append!(backend, a::A, b::Number, lb) where {A<:oneAPI.oneArray}
4242
la = length(a)
4343
aa = similar(a, la + lb)
4444
copyto!(view(aa, 1:la), a)
@@ -48,16 +48,6 @@ end
4848

4949
ExaModels.convert_array(v, backend::oneAPI.oneAPIBackend) = oneAPI.oneArray(v)
5050

51-
ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneVector} =
51+
ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneArray} =
5252
copyto!(array, sort!(Array(array); lt = lt))
53-
54-
# below is type piracy
55-
function Base.findall(f::F, bitarray::A) where {F<:Function,A<:oneAPI.oneVector}
56-
a = Array(bitarray)
57-
b = findall(f, a)
58-
c = similar(bitarray, eltype(b), length(b))
59-
return copyto!(c, b)
60-
end
61-
Base.findall(bitarray::A) where {A<:oneAPI.oneVector} = Base.findall(identity, bitarray)
62-
6353
end # module

src/nlp.jl

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -690,20 +690,20 @@ end
690690
f.o0 + f.f.first(itr[i], nothing)
691691
@inbounds @inline offset0(f::F, itr, i) where {T<:Tuple,P<:Pair{T},F<:SIMDFunction{P}} = f.o0 + idxx(coord(itr, i, f.f.first), Base.size(itr))
692692

693-
idx(itr, I) = @inbounds itr[I]
694-
idx(itr::Base.Iterators.ProductIterator{V}, I) where V = _idx(I-1, itr.iterators, Base.size(itr))
695-
function _idx(n, (vec1, vec...), (si1, si...))
693+
@inline idx(itr, I) = @inbounds itr[I]
694+
@inline idx(itr::Base.Iterators.ProductIterator{V}, I) where V = _idx(I-1, itr.iterators, Base.size(itr))
695+
@inline function _idx(n, (vec1, vec...), (si1, si...))
696696
d, r = divrem(n, si1)
697697
return (vec1[r + 1], _idx(d, vec, si)...)
698698
end
699-
_idx(n, (vec,), ::Tuple{Int}) = @inbounds vec[n + 1]
699+
@inline _idx(n, (vec,), ::Tuple{Int}) = @inbounds vec[n + 1]
700700

701-
idxx(coord, si) = _idxx(coord, si, 1) + 1
702-
_idxx((c,coord...), (s,si...), a) = a * (c - 1) + _idxx(coord, si, a*s)
703-
_idxx(::Tuple{}, ::Tuple{}, a) = 0
701+
@inline idxx(coord, si) = _idxx(coord, si, 1) + 1
702+
@inline _idxx(coord, si, a) = a * (coord[1] - 1) + _idxx(coord[2:end], si[2:end], a*si[1])
703+
@inline _idxx(::Tuple{}, ::Tuple{}, a) = 0
704704

705-
coord(itr, i, (f,fs...)) = (f(idx(itr,i), nothing), coord(itr, i, fs)...)
706-
coord(itr, i, ::Tuple{}) = ()
705+
@inline coord(itr, i, (f,fs...)) = (f(idx(itr,i), nothing), coord(itr, i, fs)...)
706+
@inline coord(itr, i, ::Tuple{}) = ()
707707

708708
for (thing, val) in [(:solution, 1), (:multipliers_L, 0), (:multipliers_U, 2)]
709709
@eval begin

0 commit comments

Comments
 (0)