11# Tune imports.
2- import os
3- from typing import Dict , Union , List
2+ from typing import Dict
43
54import ray
65import logging
1110from lightgbm .callback import CallbackEnv
1211
1312from xgboost_ray .session import put_queue
14- from xgboost_ray .util import Unavailable , force_on_current_node
13+ from xgboost_ray .util import force_on_current_node
1514
1615try :
1716 from ray import tune
1817 from ray .tune import is_session_enabled
19- from ray .tune .utils import flatten_dict
18+ from ray .tune .integration .lightgbm import (
19+ TuneReportCallback as OrigTuneReportCallback , _TuneCheckpointCallback
20+ as _OrigTuneCheckpointCallback , TuneReportCheckpointCallback as
21+ OrigTuneReportCheckpointCallback )
2022
2123 TUNE_INSTALLED = True
2224except ImportError :
2527 def is_session_enabled ():
2628 return False
2729
28- flatten_dict = is_session_enabled
2930 TUNE_INSTALLED = False
3031
31- try :
32- from ray .tune .integration .lightgbm import \
33- TuneReportCallback as OrigTuneReportCallback , \
34- _TuneCheckpointCallback as _OrigTuneCheckpointCallback , \
35- TuneReportCheckpointCallback as OrigTuneReportCheckpointCallback
36- except ImportError :
37- TuneReportCallback = _TuneCheckpointCallback = \
38- TuneReportCheckpointCallback = Unavailable
39- OrigTuneReportCallback = _OrigTuneCheckpointCallback = \
40- OrigTuneReportCheckpointCallback = object
41-
42- if not hasattr (OrigTuneReportCallback , "_get_report_dict" ):
43- TUNE_LEGACY = True
44- else :
45- TUNE_LEGACY = False
46-
47- try :
48- from ray .tune import PlacementGroupFactory
49-
50- TUNE_USING_PG = True
51- except ImportError :
52- TUNE_USING_PG = False
53- PlacementGroupFactory = Unavailable
54-
5532
5633class _TuneLGBMRank0Mixin :
5734 """Mixin to allow for dynamic setting of rank so that only
@@ -69,115 +46,8 @@ def is_rank_0(self, val: bool):
6946 self ._is_rank_0 = val
7047
7148
72- if TUNE_LEGACY and TUNE_INSTALLED :
73-
74- class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
75- """Create a callback that reports metrics to Ray Tune."""
76- order = 20
77-
78- def __init__ (
79- self ,
80- metrics : Union [None , str , List [str ], Dict [str , str ]] = None ):
81- if isinstance (metrics , str ):
82- metrics = [metrics ]
83- self ._metrics = metrics
84-
85- def _get_report_dict (self ,
86- evals_log : Dict [str , Dict [str , list ]]) -> dict :
87- result_dict = flatten_dict (evals_log , delimiter = "-" )
88- if not self ._metrics :
89- report_dict = result_dict
90- else :
91- report_dict = {}
92- for key in self ._metrics :
93- if isinstance (self ._metrics , dict ):
94- metric = self ._metrics [key ]
95- else :
96- metric = key
97- report_dict [key ] = result_dict [metric ]
98- return report_dict
99-
100- def _get_eval_result (self , env : CallbackEnv ) -> dict :
101- eval_result = {}
102- for data_name , eval_name , result , _ in env .evaluation_result_list :
103- if data_name not in eval_result :
104- eval_result [data_name ] = {}
105- eval_result [data_name ][eval_name ] = result
106- return eval_result
107-
108- def __call__ (self , env : CallbackEnv ) -> None :
109- if not self .is_rank_0 :
110- return
111- eval_result = self ._get_eval_result (env )
112- report_dict = self ._get_report_dict (eval_result )
113- put_queue (lambda : tune .report (** report_dict ))
114-
115- class _TuneCheckpointCallback (_TuneLGBMRank0Mixin ,
116- _OrigTuneCheckpointCallback ):
117- """LightGBM checkpoint callback"""
118- order = 19
119-
120- def __init__ (self ,
121- filename : str = "checkpoint" ,
122- frequency : int = 5 ,
123- * ,
124- is_rank_0 : bool = False ):
125- self ._filename = filename
126- self ._frequency = frequency
127- self .is_rank_0 = is_rank_0
128-
129- @staticmethod
130- def _create_checkpoint (model : Booster , epoch : int , filename : str ,
131- frequency : int ):
132- if epoch % frequency > 0 :
133- return
134- with tune .checkpoint_dir (step = epoch ) as checkpoint_dir :
135- model .save_model (os .path .join (checkpoint_dir , filename ))
136-
137- def __call__ (self , env : CallbackEnv ) -> None :
138- if not self .is_rank_0 :
139- return
140- put_queue (lambda : self ._create_checkpoint (
141- env .model , env .iteration , self ._filename , self ._frequency ))
142-
143- class TuneReportCheckpointCallback (_TuneLGBMRank0Mixin ,
144- OrigTuneReportCheckpointCallback ):
145- """Creates a callback that reports metrics and checkpoints model."""
146- order = 21
147-
148- _checkpoint_callback_cls = _TuneCheckpointCallback
149- _report_callback_cls = TuneReportCallback
150-
151- def __init__ (
152- self ,
153- metrics : Union [None , str , List [str ], Dict [str , str ]] = None ,
154- filename : str = "checkpoint" ,
155- frequency : int = 5 ):
156- self ._checkpoint = self ._checkpoint_callback_cls (
157- filename , frequency )
158- self ._report = self ._report_callback_cls (metrics )
159-
160- @property
161- def is_rank_0 (self ) -> bool :
162- try :
163- return self ._is_rank_0
164- except AttributeError :
165- return False
166-
167- @is_rank_0 .setter
168- def is_rank_0 (self , val : bool ):
169- self ._is_rank_0 = val
170- if hasattr (self , "_checkpoint" ):
171- self ._checkpoint .is_rank_0 = val
172- if hasattr (self , "_report" ):
173- self ._report .is_rank_0 = val
174-
175- def __call__ (self , env : CallbackEnv ) -> None :
176- self ._checkpoint (env )
177- self ._report (env )
49+ if TUNE_INSTALLED :
17850
179- elif TUNE_INSTALLED :
180- # New style callbacks.
18151 class TuneReportCallback (_TuneLGBMRank0Mixin , OrigTuneReportCallback ):
18252 def __call__ (self , env : CallbackEnv ) -> None :
18353 if not self .is_rank_0 :
@@ -241,15 +111,10 @@ def _try_add_tune_callback(kwargs: Dict):
241111 target = "lightgbm_ray.tune.TuneReportCallback" ))
242112 has_tune_callback = True
243113 elif isinstance (cb , OrigTuneReportCheckpointCallback ):
244- if TUNE_LEGACY :
245- replace_cb = TuneReportCheckpointCallback (
246- metrics = cb ._report ._metrics ,
247- filename = cb ._checkpoint ._filename )
248- else :
249- replace_cb = TuneReportCheckpointCallback (
250- metrics = cb ._report ._metrics ,
251- filename = cb ._checkpoint ._filename ,
252- frequency = cb ._checkpoint ._frequency )
114+ replace_cb = TuneReportCheckpointCallback (
115+ metrics = cb ._report ._metrics ,
116+ filename = cb ._checkpoint ._filename ,
117+ frequency = cb ._checkpoint ._frequency )
253118 new_callbacks .append (replace_cb )
254119 logging .warning (
255120 REPLACE_MSG .format (
0 commit comments