Skip to content

Commit 0b4977a

Browse files
authored
Set more sensible default threads count (#196)
* Set more sensible default threads count * Don't care about pre 1.6 * Typo * add Nothing back into the union
1 parent 07b3a57 commit 0b4977a

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

src/booster.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ function Booster(cache::AbstractVector{<:DMatrix};
108108
model_file::AbstractString="",
109109
tree_method::Union{Nothing,AbstractString}=nothing,
110110
validate_parameters::Bool=true,
111+
nthread=Threads.nthreads(),
111112
kw...
112113
)
113114
o = Ref{BoosterHandle}()
@@ -124,7 +125,7 @@ function Booster(cache::AbstractVector{<:DMatrix};
124125
else
125126
(tree_method=tree_method,)
126127
end
127-
setparams!(b; validate_parameters, tm..., kw...)
128+
setparams!(b; validate_parameters, nthread, tm..., kw...)
128129
b
129130
end
130131
Booster(dm::DMatrix; kw...) = Booster([dm]; kw...)

src/dmatrix.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,13 +581,13 @@ function _unsafe_dataiter_reset(ptr::Ptr)
581581
end
582582

583583
function _dmatrix_caching_config_json(;cache_prefix::AbstractString,
584-
nthreads::Union{Integer,Nothing},
584+
nthreads::Union{Integer, Nothing},
585585
missing_value::Float32=NaN32,
586586
)
587587
d = Dict("missing"=>"__NAN_STR__",
588588
"cache_prefix"=>cache_prefix,
589589
)
590-
isnothing(nthreads) || (d["nthreads"] = nthreads)
590+
isnothing(nthreads) || (d["nthreads"] = string(nthreads))
591591
# this is to strip out the special Float32 values to representations it'll accept
592592
nanstr = if isnan(missing_value)
593593
"NaN"
@@ -603,7 +603,7 @@ end
603603
function DMatrix(itr::DataIterator;
604604
missing_value::Float32=NaN32,
605605
cache_prefix::AbstractString=joinpath(tempdir(),"xgb-cache"),
606-
nthreads::Union{Integer,Nothing}=nothing,
606+
nthreads::Union{Integer, Nothing}=Threads.nthreads(),
607607
kw...
608608
)
609609
o = Ref{DMatrixHandle}()

0 commit comments

Comments
 (0)