Merged
185 changes: 131 additions & 54 deletions model_zoo/gpt-3/ppfleetx/core/engine/auto_engine.py
@@ -31,6 +31,17 @@
from ppfleetx.utils.log import convert_timestamp_to_data, get_timestamp, logger
from ppfleetx.utils.version import version_check

def use_new_executor():
new_executor_micro_batching = os.environ.get(
'FLAGS_new_executor_micro_batching', None
)
return new_executor_micro_batching in [
1,
'1',
True,
'True',
'true',
]

class AutoEngine(BasicEngine):
def __init__(self, configs, module=None, mode="train"):
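For context on the toggle that drives every change in this file, here is a standalone sketch of how the new use_new_executor() helper behaves for a few representative settings of FLAGS_new_executor_micro_batching. The helper body is copied from the hunk above; the probe loop is illustrative only and not part of the PR.

import os

def use_new_executor():
    # Copied from the hunk above. Environment values arrive as strings, so in
    # practice only "1", "True", and "true" enable the new executor path.
    new_executor_micro_batching = os.environ.get(
        'FLAGS_new_executor_micro_batching', None
    )
    return new_executor_micro_batching in [1, '1', True, 'True', 'true']

# Illustrative probe of a few representative settings (not part of the PR).
for value in ['1', 'true', '0', None]:
    if value is None:
        os.environ.pop('FLAGS_new_executor_micro_batching', None)
    else:
        os.environ['FLAGS_new_executor_micro_batching'] = value
    print(repr(value), '->', use_new_executor())
# Prints: '1' -> True, 'true' -> True, '0' -> False, None -> False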
@@ -152,7 +163,10 @@ def _train_one_epoch(self, epoch_index, train_data_loader=None, valid_data_loade

total_train_batch = self._max_steps if self._run_mode == "step" else len(train_data_loader)
total_train_step = self._max_steps if self._run_mode == "step" else total_train_batch * self._num_train_epochs
total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
if use_new_executor():
total_eval_batch = len(valid_data_loader) if valid_data_loader is not None else 0
else:
total_eval_batch = valid_data_loader._steps if valid_data_loader is not None else 0
valid_data_loader = valid_data_loader if valid_data_loader is not None else None
eval_finished_step = 0

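The only behavioural difference in this hunk is how the number of evaluation batches is obtained: the standard dataloader supports len(), while the generator-based loader used on the legacy path is assumed to expose its step count through a private _steps attribute. A minimal sketch of the same selection follows; DataLoaderWithLen and GeneratorLoader are stand-ins invented for illustration.

# Hypothetical stand-ins for the two loader flavours; the names are not from the PR.
class DataLoaderWithLen:
    def __init__(self, num_batches):
        self._num_batches = num_batches

    def __len__(self):
        return self._num_batches


class GeneratorLoader:
    def __init__(self, steps):
        self._steps = steps  # mirrors the private attribute read in the hunk above


def count_eval_batches(loader, new_executor):
    # Same selection as above: len() for the standard loader, _steps otherwise.
    if loader is None:
        return 0
    return len(loader) if new_executor else loader._steps


print(count_eval_batches(DataLoaderWithLen(8), new_executor=True))   # 8
print(count_eval_batches(GeneratorLoader(8), new_executor=False))    # 8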
@@ -163,26 +177,41 @@
if step < self._load_recovery["step"]:
continue

batches = self._validate_batch(batch)

fetch_list = None
if self._strategy.amp.enable:
# fetch_list = ["find_infinite_scale.tmp_0", "loss_scaling_0"]
fetch_list = []

final_loss = None
for micro_batch in batches:
with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
# pp: some devices don't have loss in outs
if "loss" in outs:
if final_loss is None:
final_loss = np.sum(outs["loss"])
else:
final_loss += np.sum(outs["loss"])

if final_loss is not None and self._accumulate_steps > 1:
final_loss /= self._accumulate_steps
if use_new_executor():
batches = self._validate_batch(batch)
for micro_batch in batches:
with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
outs = self._auto_engine.run(micro_batch, fetch_list=fetch_list, mode="train")
# pp: some devices don't have loss in outs
if "loss" in outs:
if final_loss is None:
final_loss = np.sum(outs["loss"])
else:
final_loss += np.sum(outs["loss"])

if final_loss is not None and self._accumulate_steps > 1:
final_loss /= self._accumulate_steps
else:
if self._pp_degree == 1 and self._accumulate_steps > 1: # gradient merge
local_steps = self._accumulate_steps
else:
local_steps = 1
for _ in range(local_steps):
with paddle.profiler.utils._nvprof_range(iter_id=step, start=self.nvprof_start, end=self.nvprof_end):
outs = self._auto_engine.run(batch, fetch_list=fetch_list, mode="train")
# pp: some devices don't have loss in outs
if "loss" in outs:
if final_loss is None:
final_loss = np.sum(outs["loss"])
else:
final_loss += np.sum(outs["loss"])

if final_loss is not None:
train_losses.append(final_loss)
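Both branches above perform the same bookkeeping: run the engine once per micro-batch (new executor) or once per local gradient-merge step (legacy executor), sum the loss from whichever outputs actually contain one, and, on the new-executor path, average the sum over the number of accumulation steps. Here is a self-contained sketch of that reduction; the outs dicts below are fabricated and only mimic the engine outputs.

import numpy as np

def reduce_step_loss(per_run_outs, accumulate_steps):
    # Mirrors the loop above: under pipeline parallelism some devices do not
    # report a loss, so only sum the outputs that contain one.
    final_loss = None
    for outs in per_run_outs:
        if 'loss' in outs:
            if final_loss is None:
                final_loss = np.sum(outs['loss'])
            else:
                final_loss += np.sum(outs['loss'])
    # New-executor path: average the summed loss over the accumulation steps.
    if final_loss is not None and accumulate_steps > 1:
        final_loss /= accumulate_steps
    return final_loss

# Fabricated outputs for four micro-batches; one stage reports no loss.
outs_per_micro_batch = [
    {'loss': np.array([2.0])},
    {},
    {'loss': np.array([4.0])},
    {'loss': np.array([6.0])},
]
print(reduce_step_loss(outs_per_micro_batch, accumulate_steps=4))  # (2 + 4 + 6) / 4 = 3.0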
@@ -267,27 +296,49 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):

train_data_loader, valid_data_loader = None, None
if train_dataset:
train_data_loader = self._auto_engine.dataloader(
dataset=train_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=train_dataset.collate_fn,
num_workers=1,
sample_split=train_dataset.sample_split,
mode="train",
)
if use_new_executor():
train_data_loader = self._auto_engine.dataloader(
dataset=train_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=train_dataset.collate_fn,
num_workers=1,
sample_split=train_dataset.sample_split,
mode="train",
)
else:
train_data_loader = self._auto_engine.dataloader_from_generator(
dataset=train_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=train_dataset.collate_fn,
sample_split=train_dataset.sample_split,
mode="train",
)
if valid_dataset and self._eval_freq <= self._max_steps:
valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
if use_new_executor():
valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
else:
valid_data_loader = self._auto_engine.dataloader_from_generator(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
sample_split=valid_dataset.sample_split,
mode="eval",
)

for epoch_index in range(start_epoch, epoch):
train_epoch_start = get_timestamp()
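The executor-dependent choice between dataloader and dataloader_from_generator made above for the train and validation loaders is repeated almost verbatim in evaluate() and predict() below. As a sketch only, the branch could be exercised through a single helper; build_loader is an invented name, engine stands in for self._auto_engine, and the keyword set simply mirrors the calls above.

def build_loader(engine, dataset, global_batch_size, max_steps, num_epochs, mode, new_executor):
    # Illustrative helper mirroring the repeated branch in fit/evaluate/predict:
    # the standard dataloader under the new executor, the generator-based one otherwise.
    common = dict(
        dataset=dataset,
        batch_size=global_batch_size,
        steps_per_epoch=max_steps,
        epochs=num_epochs,
        collate_fn=dataset.collate_fn,
        sample_split=dataset.sample_split,
        mode=mode,
    )
    if new_executor:
        return engine.dataloader(num_workers=1, **common)
    return engine.dataloader_from_generator(**common)

# e.g. train_data_loader = build_loader(self._auto_engine, train_dataset,
#          self._global_batch_size, self._max_steps, self._num_train_epochs,
#          "train", use_new_executor())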
@@ -320,6 +371,8 @@ def fit(self, epoch=1, train_dataset=None, valid_dataset=None):
convert_timestamp_to_data(get_timestamp() - train_start)
)
)
if valid_data_loader and not use_new_executor():
valid_data_loader._inner_dataloader.reset()

if self.profiler:
self._profiler_done()
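The epoch-end reset added just above only applies to the legacy generator-based loader, which is assumed to wrap a Paddle DataLoader.from_generator object behind the private _inner_dataloader attribute and to be exhausted after one pass. A toy, self-contained sketch of the same pattern follows; ToyGeneratorLoader is invented for illustration and is not the real loader class.

class ToyGeneratorLoader:
    # Illustrative stand-in: exhausted after one pass, must be reset per epoch.
    def __init__(self, steps):
        self._steps = steps
        self._cursor = 0
        self._inner_dataloader = self  # the real loader wraps a Paddle dataloader here

    def __iter__(self):
        return self

    def __next__(self):
        if self._cursor >= self._steps:
            raise StopIteration
        self._cursor += 1
        return self._cursor

    def reset(self):
        self._cursor = 0


loader = ToyGeneratorLoader(steps=3)
for epoch in range(2):
    print(f'epoch {epoch}:', list(loader))
    # Mirrors the change above: only the legacy path needs the explicit reset.
    loader._inner_dataloader.reset()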
@@ -328,16 +381,28 @@ def evaluate(self, epoch=1, valid_dataset=None):

valid_data_loader = None
if valid_dataset:
valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
if use_new_executor():
valid_data_loader = self._auto_engine.dataloader(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)
else:
valid_data_loader = self._auto_engine.dataloader_from_generator(
dataset=valid_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=valid_dataset.collate_fn,
num_workers=1,
sample_split=valid_dataset.sample_split,
mode="eval",
)

for epoch_index in range(epoch):
eval_epoch_start = get_timestamp()
@@ -388,16 +453,28 @@ def predict(self, epoch=1, test_dataset=None):

test_data_loader = None
if test_dataset:
test_data_loader = self._auto_engine.dataloader(
dataset=test_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=test_dataset.collate_fn,
num_workers=1,
sample_split=test_dataset.sample_split,
mode="predict",
)
if use_new_executor():
test_data_loader = self._auto_engine.dataloader(
dataset=test_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=test_dataset.collate_fn,
num_workers=1,
sample_split=test_dataset.sample_split,
mode="predict",
)
else:
test_data_loader = self._auto_engine.dataloader_from_generator(
dataset=test_dataset,
batch_size=self._global_batch_size,
steps_per_epoch=self._max_steps,
epochs=self._num_train_epochs,
collate_fn=test_dataset.collate_fn,
num_workers=1,
sample_split=test_dataset.sample_split,
mode="predict",
)

test_start = get_timestamp()
test_losses = []