Skip to content

Commit 02e1249

Browse files
authored
copy predicted array so we dont overwrite used memory location (#188)
* copy predicted array so we dont overwrite used memory location * use mersenee instead of xoshiro for 1.6 * add predict_nocopy and stop using transpose
1 parent 359175c commit 02e1249

File tree

4 files changed

+43
-23
lines changed

4 files changed

+43
-23
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "XGBoost"
22
uuid = "009559a3-9522-5dbb-924b-0b6ed2b22bb9"
3-
version = "2.3.1"
3+
version = "2.3.2"
44

55
[deps]
66
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

docs/src/api.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Booster
3939
updateone!
4040
update!
4141
predict
42+
predict_nocopy
4243
setparam!
4344
setparams!
4445
getnrounds

src/booster.jl

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -275,29 +275,18 @@ function deserialize(::Type{Booster}, buf::AbstractVector{UInt8}, data=DMatrix[]
275275
deserialize!(b, buf)
276276
end
277277

278-
# sadly this is type unstable because we might return a transpose
279278
"""
280-
predict(b::Booster, data; margin=false, training=false, ntree_limit=0)
279+
predict_nocopy(b::Booster, data; kw...)
281280
282-
Use the model `b` to run predictions on `data`. This will return a `Vector{Float32}` which can be compared
283-
to training or test target data.
284-
285-
If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction.
286-
287-
## Examples
288-
```julia
289-
(X, y) = (randn(100,3), randn(100))
290-
b = xgboost((X, y), 10)
291-
292-
ŷ = predict(b, X)
293-
```
281+
Same as [`predict`](@ref), but the output array is not copied. Data in the array output
282+
by this function may be overwritten by future calls to `predict_nocopy` or `predict`.
294283
"""
295-
function predict(b::Booster, Xy::DMatrix;
296-
margin::Bool=false, # whether to output margin
297-
training::Bool=false,
298-
ntree_lower_limit::Integer=0,
299-
ntree_limit::Integer=0, # 0 corresponds to no limit
300-
)
284+
function predict_nocopy(b::Booster, Xy::DMatrix;
285+
margin::Bool=false, # whether to output margin
286+
training::Bool=false,
287+
ntree_lower_limit::Integer=0,
288+
ntree_limit::Integer=0, # 0 corresponds to no limit
289+
)
301290
opts = Dict("type"=>(margin ? 1 : 0),
302291
"iteration_begin"=>ntree_lower_limit,
303292
"iteration_end"=>ntree_limit,
@@ -309,9 +298,31 @@ function predict(b::Booster, Xy::DMatrix;
309298
o = Ref{Ptr{Cfloat}}()
310299
xgbcall(XGBoosterPredictFromDMatrix, b.handle, Xy.handle, opts, oshape, odim, o)
311300
dims = reverse(unsafe_wrap(Array, oshape[], odim[]))
301+
# this `copy` is needed because libxgboost re-uses the pointer
312302
o = unsafe_wrap(Array, o[], tuple(dims...))
313-
length(dims) > 1 ? transpose(o) : o
303+
length(dims) > 1 ? permutedims(o) : o
314304
end
305+
306+
predict_nocopy(b::Booster, Xy; kw...) = predict_nocopy(b, DMatrix(Xy); kw...)
307+
308+
"""
309+
predict(b::Booster, data; margin=false, training=false, ntree_limit=0)
310+
311+
Use the model `b` to run predictions on `data`. This will return a `Vector{Float32}` which can be compared
312+
to training or test target data.
313+
314+
If `ntree_limit > 0` only the first `ntree_limit` trees will be used in prediction.
315+
316+
## Examples
317+
```julia
318+
(X, y) = (randn(100,3), randn(100))
319+
b = xgboost((X, y), 10)
320+
321+
ŷ = predict(b, X)
322+
```
323+
"""
324+
predict(b::Booster, Xy::DMatrix; kw...) = copy(predict_nocopy(b, Xy; kw...))
325+
315326
predict(b::Booster, Xy; kw...) = predict(b, DMatrix(Xy); kw...)
316327

317328
function evaliter(b::Booster, watch, n::Integer=1)

test/runtests.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ end
179179
@test Term.Panel(bst) isa Term.Panel
180180
end
181181

182-
@testset "Booster Save/Load/Serialize" begin
182+
@testset "Booster" begin
183183
dtrain = XGBoost.load(DMatrix, testfilepath("agaricus.txt.train"))
184184
dtest = XGBoost.load(DMatrix, testfilepath("agaricus.txt.test"))
185185

@@ -217,6 +217,14 @@ end
217217
bst2 = Booster(DMatrix[])
218218
XGBoost.deserialize!(bst2, bin)
219219
@test preds == predict(bst2, dtest)
220+
221+
# libxgboost re-uses the prediction memory location,
222+
# so we are testing to make sure we don't do that
223+
rng = MersenneTwister(999) # note that Xoshiro is not available on 1.6
224+
(X, y) = (randn(rng, 10,2), randn(rng, 10))
225+
b = xgboost((X, y))
226+
= predict(b, X)
227+
@test predict(b, randn(MersenneTwister(998), 10,2))
220228
end
221229

222230
has_cuda() && @testset "cuda" begin

0 commit comments

Comments
 (0)