Skip to content

Commit bf4b6b5

Browse files
authored
Fix multi-threaded interpolation (#34)
1 parent 1b84167 commit bf4b6b5

File tree

3 files changed

+35
-29
lines changed

3 files changed

+35
-29
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Júlio Hoffimann <[email protected]> and contributors"]
44
version = "0.12.5"
55

66
[deps]
7+
ChunkSplitters = "ae650224-84b6-46f8-82ea-d812ca08434e"
78
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
89
CoordRefSystems = "b46f11dc-f210-4604-bfba-323c1ec968cb"
910
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
@@ -17,6 +18,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1718
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
1819

1920
[compat]
21+
ChunkSplitters = "3.1"
2022
Combinatorics = "1.0"
2123
CoordRefSystems = "0.18"
2224
Distances = "0.10"

src/GeoStatsModels.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ using Distances
1616
using Unitful
1717
using Tables
1818

19+
using ChunkSplitters: index_chunks
1920
using LinearAlgebra: QRIteration
2021

2122
include("models.jl")

src/models.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -158,21 +158,15 @@ function fitpredictneigh(model, dat, dom, point, prob, minneighbors, maxneighbor
158158
KBallSearch(domain(dat), maxneighbors, neighborhood)
159159
end
160160

161-
# pre-allocate memory for neighbors
162-
neighbors = [Vector{Int}(undef, maxneighbors) for _ in 1:Threads.nthreads()]
163-
164-
# prediction at index
165-
function prediction(ind)
166-
# neighbors in current thread
167-
tneighbors = neighbors[Threads.threadid()]
168-
161+
# prediction of single index
162+
function predictsingle!(neighbors, ind)
169163
# find neighbors with data
170-
n = search!(tneighbors, centroid(dom, ind), searcher)
164+
n = search!(neighbors, centroid(dom, ind), searcher)
171165

172166
# predict if enough neighbors
173167
vals = if n minneighbors
174168
# view neighborhood with data
175-
ninds = view(tneighbors, 1:n)
169+
ninds = view(neighbors, 1:n)
176170
ndata = view(dat, ninds)
177171

178172
# fit model with neighborhood
@@ -189,11 +183,14 @@ function fitpredictneigh(model, dat, dom, point, prob, minneighbors, maxneighbor
189183
(; zip(vars, vals)...)
190184
end
191185

186+
# pre-allocate memory for neighbors
187+
state = [Vector{Int}(undef, maxneighbors) for _ in 1:Threads.nthreads()]
188+
192189
# perform prediction
193190
preds = if isthreaded()
194-
_predictionthread(prediction, inds)
191+
_predictionthread!(state, predictsingle!, inds)
195192
else
196-
_predictionserial(prediction, inds)
193+
_predictionserial!(state, predictsingle!, inds)
197194
end
198195

199196
# convert to original table type
@@ -212,39 +209,45 @@ function fitpredictfull(model, dat, dom, point, prob)
212209
vars = Tables.columnnames(cols)
213210
inds = 1:nelements(dom)
214211

215-
# fit model to data
216-
fmodel = fit(model, dat)
217-
fitted = [deepcopy(fmodel) for _ in 1:Threads.nthreads()]
218-
219-
# prediction at index
220-
function prediction(ind)
221-
fmod = fitted[Threads.threadid()]
212+
# prediction of single index
213+
function predictsingle!(fmodel, ind)
222214
geom = getgeom(dom, ind)
223-
vals = predfun(fmod, vars, geom)
215+
vals = predfun(fmodel, vars, geom)
224216
(; zip(vars, vals)...)
225217
end
226218

219+
# fit model to data
220+
fmodel = fit(model, dat)
221+
222+
# copy fitted model to all threads
223+
state = [deepcopy(fmodel) for _ in 1:Threads.nthreads()]
224+
227225
# perform prediction
228226
preds = if isthreaded()
229-
_predictionthread(prediction, inds)
227+
_predictionthread!(state, predictsingle!, inds)
230228
else
231-
_predictionserial(prediction, inds)
229+
_predictionserial!(state, predictsingle!, inds)
232230
end
233231

234232
# convert to original table type
235233
preds |> Tables.materializer(values(dat))
236234
end
237235

238-
_predictionserial(prediction, inds) = map(prediction, inds)
239-
240-
function _predictionthread(prediction, inds)
241-
preds = Vector{Any}(undef, length(inds))
242-
Threads.@threads for ind in inds
243-
preds[ind] = prediction(ind)
236+
function _predictionthread!(state, predictsingle!, inds)
237+
buffer = Vector{Any}(undef, length(inds))
238+
chunks = index_chunks(inds, n=Threads.nthreads())
239+
Threads.@sync for (cind, chunk) in enumerate(chunks)
240+
Threads.@spawn begin
241+
for ind in chunk
242+
buffer[ind] = predictsingle!(state[cind], ind)
243+
end
244+
end
244245
end
245-
map(identity, preds)
246+
map(identity, buffer)
246247
end
247248

249+
_predictionserial!(state, predictsingle!, inds) = map(ind -> predictsingle!(first(state), ind), inds)
250+
248251
# ----------------
249252
# IMPLEMENTATIONS
250253
# ----------------

0 commit comments

Comments
 (0)