Skip to content

Commit 27f1b67

Browse files
committed
rollback changes in ext packages
1 parent c0807e5 commit 27f1b67

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

ext/ExaModelsAMDGPU.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,16 @@ import ExaModels, AMDGPU
44

55
ExaModels.convert_array(v, backend::AMDGPU.ROCBackend) = AMDGPU.ROCArray(v)
66

7+
ExaModels.sort!(array::A; lt = isless) where {A<:AMDGPU.ROCVector} =
8+
copyto!(array, sort!(Array(array); lt = lt))
9+
10+
# Below are type piracy
11+
function Base.findall(f::F, bitarray::A) where {F<:Function,A<:AMDGPU.ROCVector}
12+
a = Array(bitarray)
13+
b = findall(f, a)
14+
c = similar(bitarray, eltype(b), length(b))
15+
16+
return copyto!(c, b)
17+
end
18+
Base.findall(bitarray::A) where {A<:AMDGPU.ROCVector} = Base.findall(identity, bitarray)
719
end

ext/ExaModelsOneAPI.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,14 @@ ExaModels.convert_array(v, backend::oneAPI.oneAPIBackend) = oneAPI.oneArray(v)
5050

5151
ExaModels.sort!(array::A; lt = isless) where {A<:oneAPI.oneArray} =
5252
copyto!(array, sort!(Array(array); lt = lt))
53+
54+
# below is type piracy
55+
function Base.findall(f::F, bitarray::A) where {F<:Function,A<:oneAPI.oneArray}
56+
a = Array(bitarray)
57+
b = findall(f, a)
58+
c = similar(bitarray, eltype(b), length(b))
59+
return copyto!(c, b)
60+
end
61+
Base.findall(bitarray::A) where {A<:oneAPI.oneArray} = Base.findall(identity, bitarray)
62+
5363
end # module

0 commit comments

Comments
 (0)