@@ -158,21 +158,15 @@ function fitpredictneigh(model, dat, dom, point, prob, minneighbors, maxneighbor
158
158
KBallSearch (domain (dat), maxneighbors, neighborhood)
159
159
end
160
160
161
- # pre-allocate memory for neighbors
162
- neighbors = [Vector {Int} (undef, maxneighbors) for _ in 1 : Threads. nthreads ()]
163
-
164
- # prediction at index
165
- function prediction (ind)
166
- # neighbors in current thread
167
- tneighbors = neighbors[Threads. threadid ()]
168
-
161
+ # prediction of single index
162
+ function predictsingle! (neighbors, ind)
169
163
# find neighbors with data
170
- n = search! (tneighbors , centroid (dom, ind), searcher)
164
+ n = search! (neighbors , centroid (dom, ind), searcher)
171
165
172
166
# predict if enough neighbors
173
167
vals = if n ≥ minneighbors
174
168
# view neighborhood with data
175
- ninds = view (tneighbors , 1 : n)
169
+ ninds = view (neighbors , 1 : n)
176
170
ndata = view (dat, ninds)
177
171
178
172
# fit model with neighborhood
@@ -189,11 +183,14 @@ function fitpredictneigh(model, dat, dom, point, prob, minneighbors, maxneighbor
189
183
(; zip (vars, vals)... )
190
184
end
191
185
186
+ # pre-allocate memory for neighbors
187
+ state = [Vector {Int} (undef, maxneighbors) for _ in 1 : Threads. nthreads ()]
188
+
192
189
# perform prediction
193
190
preds = if isthreaded ()
194
- _predictionthread (prediction , inds)
191
+ _predictionthread! (state, predictsingle! , inds)
195
192
else
196
- _predictionserial (prediction , inds)
193
+ _predictionserial! (state, predictsingle! , inds)
197
194
end
198
195
199
196
# convert to original table type
@@ -212,39 +209,45 @@ function fitpredictfull(model, dat, dom, point, prob)
212
209
vars = Tables. columnnames (cols)
213
210
inds = 1 : nelements (dom)
214
211
215
- # fit model to data
216
- fmodel = fit (model, dat)
217
- fitted = [deepcopy (fmodel) for _ in 1 : Threads. nthreads ()]
218
-
219
- # prediction at index
220
- function prediction (ind)
221
- fmod = fitted[Threads. threadid ()]
212
+ # prediction of single index
213
+ function predictsingle! (fmodel, ind)
222
214
geom = getgeom (dom, ind)
223
- vals = predfun (fmod , vars, geom)
215
+ vals = predfun (fmodel , vars, geom)
224
216
(; zip (vars, vals)... )
225
217
end
226
218
219
+ # fit model to data
220
+ fmodel = fit (model, dat)
221
+
222
+ # copy fitted model to all threads
223
+ state = [deepcopy (fmodel) for _ in 1 : Threads. nthreads ()]
224
+
227
225
# perform prediction
228
226
preds = if isthreaded ()
229
- _predictionthread (prediction , inds)
227
+ _predictionthread! (state, predictsingle! , inds)
230
228
else
231
- _predictionserial (prediction , inds)
229
+ _predictionserial! (state, predictsingle! , inds)
232
230
end
233
231
234
232
# convert to original table type
235
233
preds |> Tables. materializer (values (dat))
236
234
end
237
235
238
- _predictionserial (prediction, inds) = map (prediction, inds)
239
-
240
- function _predictionthread (prediction, inds)
241
- preds = Vector {Any} (undef, length (inds))
242
- Threads. @threads for ind in inds
243
- preds[ind] = prediction (ind)
236
+ function _predictionthread! (state, predictsingle!, inds)
237
+ buffer = Vector {Any} (undef, length (inds))
238
+ chunks = index_chunks (inds, n= Threads. nthreads ())
239
+ Threads. @sync for (cind, chunk) in enumerate (chunks)
240
+ Threads. @spawn begin
241
+ for ind in chunk
242
+ buffer[ind] = predictsingle! (state[cind], ind)
243
+ end
244
+ end
244
245
end
245
- map (identity, preds )
246
+ map (identity, buffer )
246
247
end
247
248
249
+ _predictionserial! (state, predictsingle!, inds) = map (ind -> predictsingle! (first (state), ind), inds)
250
+
248
251
# ----------------
249
252
# IMPLEMENTATIONS
250
253
# ----------------
0 commit comments