
Commit 4ae2b03 (parent: 00d842b)

[AutoNLP] optimize log (#5021)

* optimize log
* fix test
* fix
* fix

File tree: 4 files changed, +38 −59 lines changed

paddlenlp/experimental/autonlp/auto_trainer_base.py

Lines changed: 31 additions & 4 deletions
@@ -13,7 +13,10 @@
 # limitations under the License.
 import copy
 import datetime
+import logging
 import os
+import shutil
+import sys
 from abc import ABCMeta, abstractmethod
 from typing import Any, Callable, Dict, List, Optional, Union

@@ -122,12 +125,36 @@ def _data_checks_and_inference(self, train_dataset: Dataset, eval_dataset: Datas
         Performs different data checks and inferences on the training and eval datasets
         """

-    @abstractmethod
-    def _construct_trainable(self, train_dataset: Dataset, eval_dataset: Dataset) -> Callable:
+    def _construct_trainable(self) -> Callable:
         """
         Returns the Trainable functions that contains the main preprocessing and training logic
         """

+        def trainable(model_config):
+            # import is required for proper pickling
+            from paddlenlp.utils.log import logger
+
+            stdout_handler = logging.StreamHandler(sys.stdout)
+            stdout_handler.setFormatter(logger.format)
+            logger.logger.addHandler(stdout_handler)
+
+            # construct trainer
+            model_config = model_config["candidates"]
+            trainer = self._construct_trainer(model_config)
+            # train
+            trainer.train()
+            # evaluate
+            eval_metrics = trainer.evaluate()
+            # save dygraph model
+            trainer.save_model(self.save_path)
+
+            if os.path.exists(self.training_path):
+                logger.info("Removing training checkpoints to conserve disk space")
+                shutil.rmtree(self.training_path)
+            return eval_metrics
+
+        return trainable
+
     @abstractmethod
     def _compute_metrics(self, eval_preds: EvalPrediction) -> Dict[str, float]:
         """

@@ -325,9 +352,9 @@ def train(
             tune_config=tune_config,
             run_config=RunConfig(
                 name=experiment_name,
-                log_to_file=True,
+                log_to_file="train.log",
                 local_dir=self.output_dir if self.output_dir else None,
-                callbacks=[tune.logger.CSVLoggerCallback(), tune.logger.JsonLoggerCallback()],
+                callbacks=[tune.logger.CSVLoggerCallback()],
             ),
         )
         self.training_results = self.tuner.fit()
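Why this helps: the trainable runs inside a Ray Tune worker process, and RunConfig(log_to_file="train.log") (set in the hunk above) makes Ray capture that worker's stdout/stderr into a per-trial log file. Attaching a logging.StreamHandler bound to sys.stdout to PaddleNLP's logger therefore routes the framework's log records into the same train.log. Below is a minimal, standard-library-only sketch of the same pattern; the "my_package" logger name and trainable_sketch function are stand-ins for illustration, not the commit's code.

import logging
import sys

# Stand-in for a package logger (e.g. paddlenlp.utils.log.logger) that normally
# writes only to its own handlers, not to the process's stdout.
pkg_logger = logging.getLogger("my_package")
pkg_logger.setLevel(logging.INFO)

def trainable_sketch():
    # Inside the Ray worker, attach a handler bound to sys.stdout. Everything it
    # emits becomes part of the process's stdout stream, which Ray redirects into
    # the per-trial log file when RunConfig(log_to_file=...) is set.
    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    pkg_logger.addHandler(stdout_handler)

    pkg_logger.info("this record now ends up in the trial's train.log")

trainable_sketch()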

paddlenlp/experimental/autonlp/text_classification.py

Lines changed: 1 addition & 27 deletions
@@ -16,7 +16,7 @@
 import json
 import os
 import shutil
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import numpy as np
 import paddle
@@ -372,32 +372,6 @@ def _construct_trainer(self, model_config) -> Trainer:
             raise NotImplementedError("'trainer_type' can only be one of ['Trainer', 'PromptTrainer']")
         return trainer

-    def _construct_trainable(self) -> Callable:
-        """
-        Returns the Trainable functions that contains the main preprocessing and training logic
-        """
-
-        def trainable(model_config):
-            # import is required for proper pickling
-            from paddlenlp.utils.log import logger
-
-            # construct trainer
-            model_config = model_config["candidates"]
-            trainer = self._construct_trainer(model_config)
-            # train
-            trainer.train()
-            # evaluate
-            eval_metrics = trainer.evaluate()
-            # save dygraph model
-            trainer.save_model(self.save_path)
-
-            if os.path.exists(self.training_path):
-                logger.info("Removing training checkpoints to conserve disk space")
-                shutil.rmtree(self.training_path)
-            return eval_metrics
-
-        return trainable
-
     def evaluate(self, eval_dataset: Optional[Dataset] = None, trial_id: Optional[str] = None):
         """
         Run evaluation and returns metrics from a certain `trial_id` on the given dataset.

paddlenlp/trainer/integrations.py

Lines changed: 0 additions & 19 deletions
@@ -156,25 +156,6 @@ def on_evaluate(self, args, state, control, **kwargs):
         if self.tune.is_session_enabled() and metrics is not None and isinstance(metrics, dict):
             self.session.report(metrics)

-    # report session metrics to Ray to track trial progress
-    def on_epoch_end(self, args, state, control, **kwargs):
-        if not state.is_world_process_zero:
-            return
-
-        metrics = kwargs.get("metrics", None)
-        if self.tune.is_session_enabled() and metrics is not None and isinstance(metrics, dict):
-            self.session.report(metrics)
-
-    # forward trainer logs
-    def on_log(self, args, state, control, logs=None, **kwargs):
-        if not state.is_world_process_zero:
-            return
-
-        if logs is not None:
-            # In AutoNLP's Ray setup, we pipe stdout to a stdout file for logging purposes
-            # TODO: find a better way for this
-            print(logs)
-

 INTEGRATION_TO_CALLBACK = {
     "visualdl": VisualDLCallback,

tests/experimental/autonlp/test_text_classification.py

Lines changed: 6 additions & 9 deletions
@@ -135,13 +135,12 @@ def test_multiclass(self, custom_model_candidate, hp_overrides):
         self.assertEqual(len(results_df), num_models)

         # test hp override
+        model_result = auto_trainer._get_model_result()
         if hp_overrides is not None:
             for hp_key, hp_value in hp_overrides.items():
-                result_hp_key = f"config/candidates/{hp_key}"
-                self.assertEqual(results_df[result_hp_key][0], hp_value)
+                self.assertEqual(model_result.metrics["config"]["candidates"][hp_key], hp_value)

         # test save
-        model_result = auto_trainer._get_model_result()
         trainer_type = model_result.metrics["config"]["candidates"]["trainer_type"]
         save_path = os.path.join(model_result.log_dir, auto_trainer.save_path)
         self.assertTrue(os.path.exists(os.path.join(save_path, "model_state.pdparams")))
@@ -247,13 +246,12 @@ def test_multilabel(self, custom_model_candidate, hp_overrides):
         self.assertEqual(len(results_df), num_models)

         # test hp override
+        model_result = auto_trainer._get_model_result()
         if hp_overrides is not None:
             for hp_key, hp_value in hp_overrides.items():
-                result_hp_key = f"config/candidates/{hp_key}"
-                self.assertEqual(results_df[result_hp_key][0], hp_value)
+                self.assertEqual(model_result.metrics["config"]["candidates"][hp_key], hp_value)

         # test save
-        model_result = auto_trainer._get_model_result()
         trainer_type = model_result.metrics["config"]["candidates"]["trainer_type"]
         save_path = os.path.join(model_result.log_dir, auto_trainer.save_path)
         self.assertTrue(os.path.exists(os.path.join(save_path, "model_state.pdparams")))
@@ -358,13 +356,12 @@ def test_default_model_candidate(self, language, hp_overrides):
         self.assertEqual(len(results_df), num_models)

         # test hp override
+        model_result = auto_trainer._get_model_result()
         if hp_overrides is not None:
             for hp_key, hp_value in hp_overrides.items():
-                result_hp_key = f"config/candidates/{hp_key}"
-                self.assertEqual(results_df[result_hp_key][0], hp_value)
+                self.assertEqual(model_result.metrics["config"]["candidates"][hp_key], hp_value)

         # test save
-        model_result = auto_trainer._get_model_result()
         trainer_type = model_result.metrics["config"]["candidates"]["trainer_type"]
         save_path = os.path.join(model_result.log_dir, auto_trainer.save_path)
         self.assertTrue(os.path.exists(os.path.join(save_path, "model_state.pdparams")))
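The tests now assert against the best trial's result object instead of string-indexed dataframe columns, so they no longer depend on how Ray flattens nested config keys into names like "config/candidates/<hp_key>". A rough sketch of the new access pattern, reusing the names from the diff (auto_trainer, hp_overrides, _get_model_result); the comment about the result type is an assumption for illustration.

model_result = auto_trainer._get_model_result()  # presumably a Ray result object with .metrics and .log_dir
if hp_overrides is not None:
    for hp_key, hp_value in hp_overrides.items():
        # The result's metrics keep the trial config nested under "config", so there
        # is no need to construct flattened column names like "config/candidates/<hp_key>".
        assert model_result.metrics["config"]["candidates"][hp_key] == hp_value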
