Commit f9021f8
[PIR][Prim] support PIR train (#7276)
* support PIR train
* polish code
1 parent f9d8e73 commit f9021f8

File tree

2 files changed: +105 -0 lines changed

model_zoo/gpt-3/ppfleetx/configs/nlp/gpt/pretrain_gpt_345M_single_card.yaml

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Global:
   global_batch_size:
   local_batch_size: 8
   micro_batch_size: 8
+  to_static: False
 
 
 Model:
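
The new Global.to_static switch feeds the dynamic-to-static branch of the training script added below. A minimal runnable sketch of that gate follows; the Linear layer is a stand-in for the GPT module ppfleetx builds, and cfg_to_static mirrors Global.to_static (both names are assumptions for illustration):

import paddle

# Stand-in flag mirroring Global.to_static from the YAML above.
cfg_to_static = False

# Stand-in network; ppfleetx builds the real GPT module via build_module(cfg).
net = paddle.nn.Linear(8, 8)

# Same gate as in the new training script: convert to a static graph on demand.
model = paddle.jit.to_static(net) if str(cfg_to_static).lower() == "true" else net
out = model(paddle.randn([2, 8]))
print(out.shape)  # [2, 8] on either path

Either branch produces a callable model, so the training loop below is identical for eager and static-graph runs.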
Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import time

import paddle

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, "../../")))

from ppfleetx.data import build_dataloader
from ppfleetx.distributed.apis import env
from ppfleetx.models import build_module
from ppfleetx.optims import build_lr_scheduler, build_optimizer
from ppfleetx.utils import config

class MovingAverage:
    """Running average over the most recent `window_size` samples."""

    def __init__(self, window_size=100):
        self.window_size = window_size
        self.sum = 0
        self.val = [0] * self.window_size  # sliding window of recent values
        self.cnt = 0  # real samples currently in the window, capped at window_size
        self.avg = 0

    def update(self, val, n=1):
        # Each call pushes n copies of `val` into the window.
        n = min(n, self.window_size)
        self.cnt = min(self.cnt + n, self.window_size)
        # Drop the n oldest values from the running sum, then append the new ones.
        self.sum -= sum(self.val[:n])
        self.val = self.val[n:] + [val] * n
        self.sum += val * n
        self.avg = self.sum / self.cnt
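
A quick sanity check of the window behavior (hypothetical values, not part of the commit):

# With a window of 4, the average tracks only the last 4 samples.
avg_meter = MovingAverage(window_size=4)
for v in [2.0, 4.0, 6.0, 8.0, 10.0]:
    avg_meter.update(v)
print(avg_meter.avg)  # (4 + 6 + 8 + 10) / 4 = 7.0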

def main():
    args = config.parse_args()
    cfg = config.get_config(args.config, overrides=args.override, show=False)
    paddle.device.set_device("gpu:0")
    env.set_seed(cfg.Global.seed)
    module = build_module(cfg)
    config.print_config(cfg)

    amp_config = cfg.Engine.mix_precision
    scale_loss = amp_config["scale_loss"]

    # The scaler is only used to log the loss scale below; the backward pass
    # itself runs without AMP scaling.
    scaler = paddle.amp.GradScaler(init_loss_scaling=scale_loss)

    train_data_loader = build_dataloader(cfg.Data, "Train")

    # Global.to_static gates dynamic-to-static conversion (the PIR training path).
    enable_to_static = cfg.Global.to_static
    if str(enable_to_static).lower() == "true":
        model = paddle.jit.to_static(module.model)
    else:
        model = module.model

    cfg.Optimizer.lr.update(
        {
            "epochs": cfg.Engine.num_train_epochs,
            "step_each_epoch": len(train_data_loader),
            "total_steps": cfg.Engine.max_steps,
        }
    )
    lr_scheduler = build_lr_scheduler(cfg.Optimizer.lr)
    optimizer = build_optimizer(cfg.Optimizer, model, lr_scheduler)

    global_batch_size = cfg.Global.global_batch_size
    max_steps = cfg.Engine.max_steps
    for step, batch in enumerate(train_data_loader()):
        if step > max_steps:
            break  # stop instead of idling through the rest of the dataloader
        init_time = time.time()
        tokens, position_ids, labels, loss_mask = batch

        preds = model(tokens, position_ids)
        loss = module.loss_fn(preds, labels, loss_mask)

        loss.backward()
        optimizer.step()

        optimizer.clear_grad()
        lr_scheduler.step(global_batch_size)
        during_time = time.time() - init_time

        print(
            "step: %d/%d\t" % (step, max_steps),
            "loss: %.6f\t" % float(loss),
            "lr: %.6g\t" % optimizer.get_lr(),
            "loss_scale: %.1f\t" % float(scaler._scale),
            "batch time: %.4f s" % during_time,
        )


if __name__ == "__main__":
    main()
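
With Global.to_static left at False, this loop trains the GPT module eagerly; flipping it to True (in the YAML above, or via a config override at launch) routes the same loop through paddle.jit.to_static, which is the static-graph PIR path the commit title refers to.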
