|
9 | 9 | from transformers.trainer_utils import IntervalStrategy, has_length
|
10 | 10 |
|
11 | 11 | from swift.trainers import TrainingArguments
|
| 12 | +from swift.utils import is_pai_training_job |
12 | 13 |
|
13 | 14 |
|
14 | 15 | class ProgressCallbackNew(ProgressCallback):
|
@@ -47,7 +48,7 @@ def on_log(self,
|
47 | 48 | for k, v in logs.items():
|
48 | 49 | if isinstance(v, float):
|
49 | 50 | logs[k] = round(logs[k], 8)
|
50 |
| - if state.is_local_process_zero and self.training_bar is not None: |
| 51 | + if not is_pai_training_job() and state.is_local_process_zero: |
51 | 52 | jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
|
52 | 53 | with open(jsonl_path, 'a', encoding='utf-8') as f:
|
53 | 54 | f.write(json.dumps(logs) + '\n')
|
@@ -77,7 +78,7 @@ def on_log(self, args, state, control, logs=None, **kwargs):
|
77 | 78 | for k, v in logs.items():
|
78 | 79 | if isinstance(v, float):
|
79 | 80 | logs[k] = round(logs[k], 8)
|
80 |
| - if state.is_local_process_zero: |
| 81 | + if not is_pai_training_job() and state.is_local_process_zero: |
81 | 82 | jsonl_path = os.path.join(args.output_dir, 'logging.jsonl')
|
82 | 83 | with open(jsonl_path, 'a', encoding='utf-8') as f:
|
83 | 84 | f.write(json.dumps(logs) + '\n')
|
|
0 commit comments