Commit 321faf3

add if-else use_new_executor branch (#6897)

* add if-else use_new_executor branch
* no pre-commit reformat
* add () for use_new_executor
* align old-executor and new-executor loss
* fix: outs["loss"].shape == () can't be indexed as outs["loss"][-1]
* tiny format fix
* restore the way the loss is calculated by the new executor
* the old executor doesn't have to divide by accumulate_steps
* dataloader_from_generator: add param num_workers=1

1 parent 27e25e6
File tree

1 file changed: +131 −54 lines

model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py

Lines changed: 131 additions & 54 deletions
@@ -31,6 +31,17 @@
 from ppfleetx.utils.log import convert_timestamp_to_data, get_timestamp, logger
 from ppfleetx.utils.version import version_check

+def use_new_executor():
+    new_executor_micro_batching = os.environ.get(
+        'FLAGS_new_executor_micro_batching', None
+    )
+    return new_executor_micro_batching in [
+        1,
+        '1',
+        True,
+        'True',
+        'true',
+    ]

 class AutoEngine(BasicEngine):
     def __init__(self, configs, module=None, mode="train"):
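One note on the helper: os.environ.get always returns a string (or None), so only the string spellings '1', 'True', and 'true' can ever match; the bare 1 and True list entries are harmless but unreachable. A minimal sketch of how the flag drives the branch, assuming it is exported before the engine module is imported (the launch command in the comment is illustrative):

import os

# e.g. FLAGS_new_executor_micro_batching=true python tools/train.py ...
os.environ["FLAGS_new_executor_micro_batching"] = "true"

flag = os.environ.get("FLAGS_new_executor_micro_batching", None)
print(type(flag))                              # <class 'str'> -- always a string
print(flag in [1, "1", True, "True", "true"])  # True -> new-executor paths taken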
@@ -152,7 +163,10 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):

         total_train_batch = self._max_steps if self._run_mode == "step" else len(train_data_loader)
         total_train_step = self._max_steps if self._run_mode == "step" else total_train_batch * self._num_train_epochs
-        total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
+        if use_new_executor():
+            total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
+        else:
+            total_eval_batch = valid_data_loader._steps if valid_data_loader is not None else 0
         valid_data_loader = valid_data_loader if valid_data_loader is not None else None
         eval_finished_step = 0
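The two branches exist because the loader types report their size differently: the map-style loader built for the new executor supports len(), while the generator-fed loader used by the old executor presumably does not and instead records its step count in a _steps attribute. A standalone illustration of that distinction (plain Python, not the Paddle API):

def batch_source():
    # generator-fed source: yields batches but has no length
    for i in range(8):
        yield i

sized_loader = list(range(8))   # map-style: sized
print(len(sized_loader))        # 8

try:
    len(batch_source())         # generators do not support len()
except TypeError as err:
    print(err)                  # object of type 'generator' has no len()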

@@ -163,26 +177,41 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loader=None):
             if step < self._load_recovery["step"]:
                 continue

-            batches = self._validate_batch(batch)

             fetch_list = None
             if self._strategy.amp.enable:
                 # fetch_list = ["find_infinite_scale.tmp_0", "loss_scaling_0"]
                 fetch_list = []

             final_loss = None
-            for micro_batch in batches:
-                with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
-                    outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
-                # pp: some devices don't have loss in outs
-                if "loss" in outs:
-                    if final_loss is None:
-                        final_loss = np.sum(outs["loss"])
-                    else:
-                        final_loss += np.sum(outs["loss"])
-
-            if final_loss is not None and self._accumulate_steps > 1:
-                final_loss /= self._accumulate_steps
+            if use_new_executor():
+                batches = self._validate_batch(batch)
+                for micro_batch in batches:
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])
+
+                if final_loss is not None and self._accumulate_steps > 1:
+                    final_loss /= self._accumulate_steps
+            else:
+                if self._pp_degree == 1 and self._accumulate_steps > 1:  # gradient merge
+                    local_steps = self._accumulate_steps
+                else:
+                    local_steps = 1
+                for _ in range(local_steps):
+                    with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
+                        outs = self._auto_engine.run(batch, fetch_list=fetch_list, mode="train")
+                    # pp: some devices don't have loss in outs
+                    if "loss" in outs:
+                        if final_loss is None:
+                            final_loss = np.sum(outs["loss"])
+                        else:
+                            final_loss += np.sum(outs["loss"])

             if final_loss is not None:
                 train_losses.append(final_loss)
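Per the commit notes ("align old executor and new executor loss", "old executor doesn't have to / accumulate_steps"), the division by _accumulate_steps stays on the new-executor path only: there, run() is invoked once per micro-batch and the summed losses must be averaged back to a per-step value, whereas the old-executor path leaves that normalization to the executor itself. Note also that the losses are reduced with np.sum, which works when outs["loss"] is a 0-d array (shape == ()), where indexing outs["loss"][-1] would raise. A self-contained sketch of the new-path bookkeeping with made-up values:

import numpy as np

# Each element plays the role of one micro-batch's outs["loss"] (0-d arrays).
micro_batch_losses = [np.array(2.0), np.array(4.0), np.array(6.0)]
accumulate_steps = len(micro_batch_losses)

final_loss = None
for loss in micro_batch_losses:
    # np.sum handles shape == (); loss[-1] would raise IndexError here
    final_loss = np.sum(loss) if final_loss is None else final_loss + np.sum(loss)

if final_loss is not None and accumulate_steps > 1:
    final_loss /= accumulate_steps

print(final_loss)  # 4.0 -- the mean micro-batch loss for the global batch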
@@ -267,27 +296,49 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):

         train_data_loader, valid_data_loader = None, None
         if train_dataset:
-            train_data_loader = self._auto_engine.dataloader(
-                dataset=train_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=train_dataset.collate_fn,
-                num_workers=1,
-                sample_split=train_dataset.sample_split,
-                mode="train",
-            )
+            if use_new_executor():
+                train_data_loader = self._auto_engine.dataloader(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
+            else:
+                train_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=train_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=train_dataset.collate_fn,
+                    sample_split=train_dataset.sample_split,
+                    mode="train",
+                )
         if valid_dataset and self._eval_freq <= self._max_steps:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )

         for epoch_index in range(start_epoch, epoch):
             train_epoch_start = get_timestamp()
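The same fork around loader construction repeats in evaluate() and predict() below, with only the dataset and mode changing (and with the generator-fed variants there also passing num_workers=1). A condensed, hypothetical helper showing the shared shape; build_loader is not part of the diff, and the keyword names mirror the calls above:

def build_loader(auto_engine, dataset, mode, batch_size, max_steps, epochs):
    # Hypothetical condensation of the repeated if/else in fit/evaluate/predict.
    common = dict(
        dataset=dataset,
        batch_size=batch_size,
        steps_per_epoch=max_steps,
        epochs=epochs,
        collate_fn=dataset.collate_fn,
        sample_split=dataset.sample_split,
        mode=mode,
    )
    if use_new_executor():
        # new executor: map-style loader
        return auto_engine.dataloader(num_workers=1, **common)
    # old executor: loader fed from a generator
    return auto_engine.dataloader_from_generator(**common)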
@@ -320,6 +371,8 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):
                 convert_timestamp_to_data(get_timestamp() - train_start)
             )
         )
+        if valid_data_loader and not use_new_executor():
+            valid_data_loader._inner_dataloader.reset()

         if self.profiler:
             self._profiler_done()
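The reset is needed only on the old-executor path: a generator-fed loader exhausts its underlying reader once an epoch completes and cannot simply be re-iterated, while the map-style loader used by the new executor has no such step. A toy illustration of the pattern (the class is a stand-in, not Paddle's loader):

class GeneratorLoader:
    # Toy stand-in for a loader driven by a one-shot generator.
    def __init__(self, make_gen):
        self._make_gen = make_gen
        self._gen = make_gen()

    def __iter__(self):
        return self._gen

    def reset(self):
        # Rebuild the generator so iteration can start over, analogous
        # to valid_data_loader._inner_dataloader.reset() in the diff.
        self._gen = self._make_gen()

loader = GeneratorLoader(lambda: iter(range(3)))
print(list(loader))   # [0, 1, 2]
print(list(loader))   # []  -- exhausted without a reset
loader.reset()
print(list(loader))   # [0, 1, 2]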
@@ -328,16 +381,28 @@ def evaluate(self, epoch=1, valid_dataset=None):

         valid_data_loader = None
         if valid_dataset:
-            valid_data_loader = self._auto_engine.dataloader(
-                dataset=valid_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=valid_dataset.collate_fn,
-                num_workers=1,
-                sample_split=valid_dataset.sample_split,
-                mode="eval",
-            )
+            if use_new_executor():
+                valid_data_loader = self._auto_engine.dataloader(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )
+            else:
+                valid_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=valid_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=valid_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=valid_dataset.sample_split,
+                    mode="eval",
+                )

         for epoch_index in range(epoch):
             eval_epoch_start = get_timestamp()
@@ -388,16 +453,28 @@ def predict(self, epoch=1, test_dataset=None):

         test_data_loader = None
         if test_dataset:
-            test_data_loader = self._auto_engine.dataloader(
-                dataset=test_dataset,
-                batch_size=self._global_batch_size,
-                steps_per_epoch=self._max_steps,
-                epochs=self._num_train_epochs,
-                collate_fn=test_dataset.collate_fn,
-                num_workers=1,
-                sample_split=test_dataset.sample_split,
-                mode="predict",
-            )
+            if use_new_executor():
+                test_data_loader = self._auto_engine.dataloader(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )
+            else:
+                test_data_loader = self._auto_engine.dataloader_from_generator(
+                    dataset=test_dataset,
+                    batch_size=self._global_batch_size,
+                    steps_per_epoch=self._max_steps,
+                    epochs=self._num_train_epochs,
+                    collate_fn=test_dataset.collate_fn,
+                    num_workers=1,
+                    sample_split=test_dataset.sample_split,
+                    mode="predict",
+                )

         test_start = get_timestamp()
         test_losses = []
